root/trunk/dev/ideas/synapses/spikequeue.py @ 2909

Revision 2909, 8.2 KB (checked in by romainbrette, 17 months ago)

Synapses: basic construction working

Line 
"""
Spike queues following BEP-21.

Notes
-----
A SpikeQueue object will always be attached to a Synapses object. One important point is that delays and
synapse indexes can be set after the objects are created. Therefore, they cannot be passed as fixed values at
initialization time, which is why the Synapses object is passed instead.
To save one indirection, the queue should store views on the delay and synapse arrays.

The object is a SpikeMonitor on the source NeuronGroup. When it is called, spikes are fetched from the NeuronGroup
into the queue. The way it is currently done is highly inefficient, because it only uses the following mappings:
    synapse -> delay
    synapse -> presynaptic i
but for efficiency, we need the mappings:
    presynaptic i -> synapse
    presynaptic i -> delay

We will still need the mapping synapse -> presynaptic i,
but not synapse -> delay.

Actually, does this need to be a Brian object? It could be called directly by
Synapses.

** There is no resizing for the maximum delay: the number of time steps (rows) of the queue is fixed at construction. **
"""
import numpy as np
from brian import * # remove this
from brian.stdunits import ms
29
# Default initial number of event slots (columns) per time step of a
# SpikeQueue (its ``maxevents`` keyword). SpikeQueue.insert widens the array
# on demand through resize(), so this is only a starting size.
INITIAL_MAXSPIKESPER_DT = 1
# This is a 2D circular array, but also a SpikeMonitor
32
class SpikeQueue(SpikeMonitor):
    '''
    Circular 2D array of pending synaptic events, which is also a
    SpikeMonitor on the source NeuronGroup.

    * Initialization *

    Initialized with a source NeuronGroup, the target synapses of each
    presynaptic neuron, the synaptic delays and a maximum delay.

    Arguments
    ``source`` NeuronGroup that is monitored
    ``synapses`` List of arrays of synapse indexes, one per presynaptic
        neuron (propagate() reads them through their ``.data`` attribute)
    ``delays`` Array of delays, in time steps, indexed by synapse index

    Keywords
    ``max_delay`` Maximum delay, in seconds. It fixes the number of time
        steps (rows) of the queue, which is never resized.
    ``maxevents`` Initial number of events that can be stored in each time
        step. The structure grows dynamically if there are more events than
        that, so you shouldn't bother.

    * Circular 2D array structure *

    A spike queue is implemented as a circular 2D array: one row per time
    step, one column per event slot.

    * At the beginning or end of each timestep: queue.next()
    * To get all spikes: events=queue.peek()
      It returns the indexes of all synapses receiving an event.
    * When a presynaptic spike is emitted, the following is executed:
      queue.insert(delay,offset,target)
      where delay is the array of synaptic delays of targets in timesteps,
      offset is the array of offsets within each timestep,
      target is the array of synapse indexes of targets.
      The offset is used to solve the problem of multiple synapses with the
      same delay. For example, if there are two target synapses 7 and 9 with
      delay 2 timesteps: queue.insert([2,2],[0,1],[7,9])

    Thus, offsets are determined by delays. They can be either precalculated
    (faster, see precompute_offsets()), or determined at run time (saves
    memory). When they are determined at run time, propagate() also
    vectorizes over presynaptic spikes.

    * SpikeMonitor structure *

    The queue is filled automatically through the propagate() method, which
    overrides the one of SpikeMonitor.

    Ideas:
    ------
    * remove the max_delay keyword and have the structure created with another
      method (at run time)
    '''
    def __init__(self, source, synapses, delays,
                 max_delay = 0*ms, maxevents = INITIAL_MAXSPIKESPER_DT):
        '''
        Builds an empty queue with enough rows for delays of up to
        ``max_delay`` and ``maxevents`` columns per row.

        TODO:
        * precompute offsets
        * make it work for both pre/post
        * either source or synapses is not useful, no?
        '''
        # SpikeMonitor structure
        self.source = source # NeuronGroup
        self.delays = delays
        self.synapses = synapses

        self.max_delay = max_delay # do we need this?
        # One row per delay of 0..max_delay time steps. ceil (rather than
        # floor) protects against float round-off in max_delay/dt, e.g.
        # 0.3/0.1 == 2.9999999999999996: losing a row would make the longest
        # delay wrap around onto the current time step.
        nsteps = int(np.ceil(max_delay / self.source.clock.dt)) + 1

        # Rows = time steps, columns = event slots; entries are the target
        # synapse indexes. X_flat is a flat view of X used for insertion.
        self.X = np.zeros((nsteps, maxevents), dtype = self.synapses[0].dtype)
        self.X_flat = self.X.reshape(nsteps * maxevents)
        self.currenttime = 0 # row of the current time step
        self.n = np.zeros(nsteps, dtype = int) # number of events in each time step

        self._offsets = None # precalculated offsets (see precompute_offsets)

        super(SpikeQueue, self).__init__(source, record = False)

    ################################ SPIKE QUEUE DATASTRUCTURE ######################
    def next(self):
        '''
        Advances by one time step: empties the current row and moves to the
        next one (circularly).
        '''
        self.n[self.currenttime] = 0 # erase
        self.currenttime = (self.currenttime + 1) % len(self.n)

    def peek(self):
        '''
        Returns the synapse indexes of the events in the current time step,
        as a view into the queue array (not a copy).
        '''
        return self.X[self.currenttime, :self.n[self.currenttime]]

    def precompute_offsets(self):
        '''
        Computes, for each presynaptic neuron, the offsets of its target
        synapses (see offsets()) and stores them in ``self._offsets``, so that
        propagate() does not recompute them at every spike.
        The result reflects the current delays and synapses; call this again
        if they change.
        '''
        self._offsets = [self.offsets(self.delays[targets.data])
                         for targets in self.synapses]

    def offsets(self, delay):
        '''
        Calculates offsets corresponding to a delay array.

        Entries sharing the same delay are numbered 0, 1, 2, ... (in their
        order of appearance in ``delay``), so that each gets its own slot in
        the queue row of that delay; distinct delays are numbered
        independently. For example, delays [2, 5, 2, 2] give offsets
        [0, 0, 1, 2]. The result has the dtype of ``delay``.
        '''
        delay = np.asarray(delay)
        ofs = np.zeros_like(delay)
        if len(delay) == 0:
            return ofs
        # Stable sort: equal delays keep their relative order, so offsets
        # within a group increase along the original array.
        order = np.argsort(delay, kind='mergesort')
        sorted_delay = delay[order]
        # True where a new group of equal delays starts (in sorted order)
        is_start = np.hstack((True, sorted_delay[1:] != sorted_delay[:-1]))
        # Sorted position at which each element's group starts
        group_start = np.flatnonzero(is_start)[np.cumsum(is_start) - 1]
        ofs[order] = np.arange(len(delay)) - group_start
        return ofs

    def insert(self, delay, offset, target):
        '''
        Vectorized insertion of spike events.

        ``delay``  -- delays of the events, in time steps (must be smaller
                      than the number of rows, otherwise they wrap around)
        ``offset`` -- offset of each event within its time step: events with
                      the same delay must be numbered 0, 1, ..., k-1, as
                      returned by offsets()
        ``target`` -- target synapse index of each event

        Rows are widened with resize() when they run out of columns.
        '''
        timesteps = (self.currenttime + np.asarray(delay)) % len(self.n)
        if len(timesteps) == 0:
            return # nothing to insert
        offset = np.asarray(offset)

        # Slot of each event = events already stored in its row + its offset.
        # Fancy indexing returns a copy, taken before the counts are updated.
        slots = self.n[timesteps] + offset
        # Count the new events per row. np.add.at is unbuffered, so a row hit
        # by k events is incremented k times; a plain `self.n[timesteps] += ...`
        # would apply only one of the repeated increments.
        np.add.at(self.n, timesteps, 1)

        # Widen the rows if some slot falls beyond the current width
        needed = slots.max() + 1
        if needed > self.X.shape[1]:
            self.resize(needed)

        self.X_flat[timesteps * self.X.shape[1] + slots] = target

    def resize(self, maxevents):
        '''
        Widens the underlying data structure (number of columns = events per
        time step) to at least ``maxevents`` columns, rounded up to a power
        of 2. Stored events are kept; the array is never shrunk.
        '''
        old_maxevents = self.X.shape[1]
        # Smallest power of 2 >= maxevents, as a Python int
        # (2**ceil(log2(m)) is a float, which numpy rejects as a dimension)
        new_maxevents = 1 << max(int(maxevents) - 1, 0).bit_length()
        if new_maxevents <= old_maxevents:
            return # already wide enough; shrinking would drop events
        newX = np.zeros((self.X.shape[0], new_maxevents), dtype = self.X.dtype)
        newX[:, :old_maxevents] = self.X # copy old data
        self.X = newX
        self.X_flat = self.X.reshape(self.X.shape[0] * new_maxevents)

    def propagate(self, spikes):
        '''
        SpikeMonitor hook: receives the indexes of the neurons that spiked
        and inserts one event per target synapse, delayed by its delay.
        '''
        if not len(spikes):
            return
        if self._offsets is None:
            # Offsets not precomputed: vectorise over all synaptic events
            synaptic_events = np.hstack([self.synapses[i].data for i in spikes])
            if len(synaptic_events):
                delay = self.delays[synaptic_events] # but it could be post!
                self.insert(delay, self.offsets(delay), synaptic_events)
        else:
            # Offsets are precomputed, one array per presynaptic neuron
            for i in spikes:
                # assuming a dynamic array: could change at run time?
                synaptic_events = self.synapses[i].data
                if len(synaptic_events):
                    self.insert(self.delays[synaptic_events], self._offsets[i],
                                synaptic_events)

    ######################################## UTILS
    def plot(self, display = True):
        '''
        Plots the queued events: x = row of the circular array, y = target
        synapse index, one '.' marker per event. Calls show() if ``display``
        is True.
        '''
        # The bare names plot/show below resolve to module-level functions
        # (presumably pylab's, via the brian star import), not to this method.
        nsteps = self.X.shape[0]
        for k in range(nsteps):
            row = (k + self.currenttime) % nsteps
            events = self.X[row, :self.n[row]]
            plot(row * np.ones(len(events)), events, '.')
        if display:
            show()
203
if __name__ == "__main__":
    # No demo or self-test yet: running this module as a script does nothing.
    pass
Note: See TracBrowser for help on using the browser.