# root/trunk/dev/ideas/synapses/spikequeue.py@2683

Revision 2683, 4.3 KB (checked in by romainbrette, 21 months ago)

Synapses coming to life!

Line
1"""
2Spike queues following BEP-21
3"""
from brian import *  # star import first so the explicit imports below take precedence
from time import time
import numpy as np
6
# This is a 2D circular array
class SpikeQueue(object):
    '''
    A spike queue, implemented as a circular 2D array.

    Row ``t`` of ``X`` holds the target synapses of the events due at
    timestep ``t`` (rows are reused modulo the number of timesteps), and
    ``n[t]`` is the number of slots of that row currently in use.

    * Initialize with the number of timesteps and the initial maximum number
      of events in each timestep: queue=SpikeQueue(nsteps,maxevents)
    * At the beginning or end of each timestep: queue.next()
    * To get all events of the current timestep: events=queue.peek()
      It returns the indexes of all synapses receiving an event.
    * When a presynaptic spike is emitted:
      queue.insert(delay,offset,target)
      where delay is the array of synaptic delays of targets in timesteps,
      offset is the array of offsets within each timestep,
      target is the array of synapse indexes of targets.
      The offset is used to solve the problem of multiple synapses with the
      same delay. For example, if there are two target synapses 7 and 9 with
      delay 2 timesteps: queue.insert([2,2],[0,1],[7,9])
      Events sharing a delay must have distinct offsets; queue.offsets(delay)
      computes a valid set (0, 1, 2, ... within each delay).

    Delays are taken modulo nsteps, so a delay >= nsteps wraps around the
    circular buffer and aliases a shorter delay. If a timestep receives more
    events than a row can hold, all rows are widened.

    Thus, offsets are determined by delays. They could be either
    precalculated (faster), or determined at run time (saves memory). Note
    that if they are determined at run time, then it may be possible to also
    vectorize over presynaptic spikes.
    '''
    def __init__(self, nsteps, maxevents):
        # X[t, k] is the target synapse of the k-th event stored for timestep t
        self.X = np.zeros((nsteps, maxevents), dtype=int)
        # Flat view of X, so a single fancy-indexed assignment can fill many rows
        self.X_flat = self.X.reshape(nsteps * maxevents)
        self.currenttime = 0  # row of X holding the current timestep
        self.n = np.zeros(nsteps, dtype=int)  # number of events in each timestep

    def next(self):
        '''Advance by one timestep, discarding the events of the current one.'''
        self.n[self.currenttime] = 0  # erase
        self.currenttime = (self.currenttime + 1) % len(self.n)

    def peek(self):
        '''
        Return the target synapses of all events in the current timestep.

        The result is a view into the queue's storage, not a copy.
        '''
        return self.X[self.currenttime, :self.n[self.currenttime]]

    def offsets(self, delay):
        '''
        Return, for each entry of ``delay``, its rank among the entries that
        have the same delay (0 for the first occurrence, 1 for the second...).

        Example: offsets([2, 2, 1, 2, 1]) -> [0, 1, 0, 2, 1]

        Sorting makes this O(n*log(n)). The result has the dtype of
        ``delay`` (after conversion to an array).
        '''
        delay = np.asarray(delay)
        ofs = np.zeros_like(delay)
        if delay.size == 0:
            return ofs
        # Stable sort: equal delays keep their input order, so offsets
        # increase with position in the input within each delay
        order = np.argsort(delay, kind='mergesort')
        sorted_delay = delay[order]
        # True at the first element of each run of equal delays
        is_start = np.empty(len(sorted_delay), dtype=bool)
        is_start[0] = True
        is_start[1:] = sorted_delay[1:] != sorted_delay[:-1]
        run_start = np.flatnonzero(is_start)  # sorted position where each run begins
        run_index = np.cumsum(is_start) - 1   # run each sorted position belongs to
        ofs[order] = np.arange(len(sorted_delay)) - run_start[run_index]
        return ofs

    def insert(self, delay, offset, target):
        '''
        Vectorized insertion of spike events.

        delay  -- delay of each event, in timesteps (taken modulo nsteps)
        offset -- slot of each event within its timestep; events sharing a
                  delay must have distinct offsets (see offsets())
        target -- target synapse index of each event

        Event k is stored in row (currenttime + delay[k]) % nsteps, at slot
        n[row] + offset[k]. Rows are widened if they would overflow.
        '''
        delay = np.asarray(delay, dtype=int)
        offset = np.asarray(offset, dtype=int)
        if delay.size == 0:
            return
        nsteps, maxevents = self.X.shape
        timesteps = (self.currenttime + delay) % nsteps
        slots = self.n[timesteps] + offset
        needed = int(slots.max()) + 1
        if needed > maxevents:
            # Grow at least twofold so that overflows do not resize every call
            self._resize(max(needed, 2 * maxevents))
            maxevents = self.X.shape[1]
        self.X_flat[timesteps * maxevents + slots] = target
        # Advance each touched timestep past the highest slot just used.
        # ufunc.at is unbuffered, so every repeated timestep is accounted for
        # (a plain self.n[timesteps] += ... applies only one update per index).
        used = np.zeros_like(self.n)
        np.maximum.at(used, timesteps, offset + 1)
        self.n += used

    def _resize(self, maxevents):
        '''Widen every row to hold maxevents events, keeping stored events.'''
        nsteps, old_maxevents = self.X.shape
        X = np.zeros((nsteps, maxevents), dtype=self.X.dtype)
        X[:, :old_maxevents] = self.X
        self.X = X
        self.X_flat = X.reshape(nsteps * maxevents)
77
78'''
The connection has arrays of synaptic variables (like the state matrix of
neuron groups). Two of these synaptic variables are the indexes of the
postsynaptic and presynaptic neurons (i,j), stored as int32 or int16.
82
83In addition, the connection must have, for each presynaptic neuron:
84* list of target synapses (int32)
85* corresponding delays in timesteps (int16)
86* corresponding offsets (int16 is probably sufficient, or less)
87
These types (int32, etc.) could be chosen at construction time, or when the
construction is converted into a connection (at run time).
90
The same is needed for each postsynaptic neuron (for STDP).
This could also be decided at run time (depending on whether post=None or not).
93
Total memory:
* number of synapses * 12 bytes (* 2 if bidirectional)
+ synaptic variables (weights)
97'''
98
if __name__ == '__main__':
    # Micro-benchmark: computing offsets and inserting one batch of events per
    # timestep, then advancing and reading the queue.
    nsteps = 5
    queue = SpikeQueue(nsteps, 30)
    Nsynapses = 4000 * 80  # in the CUBA example
    nspikes = 160  # average number of spikes per dt in CUBA
    # Delays are in timesteps and must fit in the circular queue. Delay 0 is
    # excluded because, with the insert/next/peek order below, next() erases
    # the current timestep before it is ever peeked.
    delays = randint(1, nsteps, size=nspikes)
    targets = randint(Nsynapses, size=nspikes)
    t1 = time()
    for _ in range(10000):  # 10000 timesteps per second
        d = queue.offsets(delays)
        queue.insert(delays, d, targets)
        queue.next()
        events = queue.peek()
    t2 = time()
    print(t2 - t1)
Note: See TracBrowser for help on using the browser.