Changeset 199

Show
Ignore:
Timestamp:
09/18/08 17:31:59 (4 months ago)
Author:
pierre
Message:

Start a reorganisation of NeuroTools, in order to have a clean and easily understandable package. The aim is to comment the functions, to make them as general as possible and to established the gap between pyNN and NeuroTools

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • branches/cleanup/src/analysis.py

    r188 r199  
    1010 
    1111 
     12def ccf(x, y, axis=None): 
     13    """Computes the cross-correlation function of two series `x` and `y`. 
     14Note that the computations are performed on anomalies (deviations from 
     15average). 
     16Returns the values of the cross-correlation at different lags. 
     17Lags are given as [0,1,2,...,n,n-1,n-2,...,-2,-1] (not any more) 
     18 
     19:Parameters: 
     20    `x` : 1D MaskedArray 
     21        Time series. 
     22    `y` : 1D MaskedArray 
     23        Time series. 
     24    `axis` : integer *[None]* 
     25        Axis along which to compute (0 for rows, 1 for cols). 
     26        If `None`, the array is flattened first. 
     27    """ 
     28    assert(x.ndim == y.ndim, "Inconsistent shape !") 
     29#    assert(x.shape == y.shape, "Inconsistent shape !") 
     30    if axis is None: 
     31        if x.ndim > 1: 
     32            x = x.ravel() 
     33            y = y.ravel() 
     34        npad = x.size + y.size 
     35        xanom = (x - x.mean(axis=None)) 
     36        yanom = (y - y.mean(axis=None)) 
     37        Fx = numpy.fft.fft(xanom, npad, ) 
     38        Fy = numpy.fft.fft(yanom, npad, ) 
     39        iFxy = numpy.fft.ifft(Fx.conj()*Fy).real 
     40        varxy = numpy.sqrt(numpy.inner(xanom,xanom) * numpy.inner(yanom,yanom)) 
     41    else: 
     42        npad = x.shape[axis] + y.shape[axis] 
     43        if axis == 1: 
     44            if x.shape[0] != y.shape[0]: 
     45                raise ValueError, "Arrays should have the same length!" 
     46            xanom = (x - x.mean(axis=1)[:,None]) 
     47            yanom = (y - y.mean(axis=1)[:,None]) 
     48            varxy = numpy.sqrt((xanom*xanom).sum(1) * (yanom*yanom).sum(1))[:,None] 
     49        else: 
     50            if x.shape[1] != y.shape[1]: 
     51                raise ValueError, "Arrays should have the same width!" 
     52            xanom = (x - x.mean(axis=0)) 
     53            yanom = (y - y.mean(axis=0)) 
     54            varxy = numpy.sqrt((xanom*xanom).sum(0) * (yanom*yanom).sum(0)) 
     55        Fx = numpy.fft.fft(xanom, npad, axis=axis) 
     56        Fy = numpy.fft.fft(yanom, npad, axis=axis) 
     57        iFxy = numpy.fft.ifft(Fx.conj()*Fy,n=npad,axis=axis).real 
     58    # We juste turn the lags into correct positions: 
     59    iFxy = numpy.concatenate((iFxy[len(iFxy)/2:len(iFxy)],iFxy[0:len(iFxy)/2])) 
     60    return iFxy/varxy 
     61 
     62 
     63 
     64 
     65 
     66 
     67 
     68 
     69 
     70 
    1271def record(output, cfilename = 'SpikeTrainPlay.wav', fs=44100, enc = 'pcm26'): 
    1372    """ record the 'sound' produced by a neuron. Takes a spike train as the 
    1473    output. 
    15  
    1674    >>> record(my_spike_train) 
    17  
    1875    """ 
    1976 
     
    2481    (trace,time) = numpy.histogram(output.spike_times*1000., fs*simtime_seconds) 
    2582 
    26  
    2783    # TODO convolve with proper spike... 
    2884    spike = numpy.ones((fs/1000.,)) # one ms 
    29  
    3085    trace = numpy.convolve(trace, spike, mode='same')#/2.0 
    3186    trace /= numpy.abs(trace).max() * 1.1 
    32  
    33     from scikits.audiolab import wavwrite 
     87    try: 
     88        from scikits.audiolab import wavwrite 
     89    except ImportError: 
     90        print "You need the scikits.audiolab package to produce sounds !" 
    3491    wavwrite(trace, cfilename, fs = fs, enc = enc) 
    3592 
     93 
     94 
    3695def play(output): 
    3796    """ 
    3897    plays a spike list to the audio output 
    39  
    4098    play(spike_list) where spike_list is a spike_list object 
    41  
    4299    see playing_with_simple_single_neuron.py for a sample use 
    43  
    44100    >>> play(my_spike_train) 
    45  
    46101    TODO: make it possible to play multiple spike trains in stereo 
    47102    """ 
  • branches/cleanup/src/plotting.py

    r186 r199  
    11 
    2 import numpy #, pylab 
    3 import sys 
     2import numpy, pylab, sys 
    43from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 
    54from matplotlib.figure import Figure 
     
    3837 
    3938 
    40 def raster_plot(spike_list,output=None):# limits of the plot 
    41     import pylab 
    42     DATA=spike_list.as_list_id_list_time() 
    43     pylab.plot(DATA[1],DATA[0],'.') 
    44     pylab.ylabel('neuron ID') 
    45     pylab.xlabel('time (s)') 
    46     pylab.axis([spike_list.t_start, spike_list.t_stop, 0, spike_list.N]) 
    47     if not(output==None): 
    48         pylab.savefig(output) 
    49     #else: 
    50         #pylab.show() 
    51  
    52  
    53  
    54  
    5539def set_frame(ax,boollist,linewidth=2): 
    5640    assert len(boollist) == 4 
     
    6549            if draw: 
    6650                ax.add_line(side) 
     51 
     52 
    6753 
    6854class SimpleMultiplot(object): 
  • branches/cleanup/src/signals.py

    r188 r199  
    44""" 
    55 
    6 import numpy 
    7 import os 
     6import numpy, os, logging, re 
     7from NeuroTools import analysis 
     8 
    89try : 
    910    import pylab 
    10 except Exception: 
    11     print "Warning: pylab not present" 
    12 from NeuroTools import analysis 
    13 import logging 
     11except ImportError: 
     12    print "Warning: pylab not detected, plots will be disabled" 
     13 
     14 
     15DEFAULT_BUFFER_SIZE = 10000 
    1416 
    1517class SpikeTrain(object): 
    1618    """ 
    17     This class defines a spike train as a list of the events times. 
     19    This class defines a spike train as a list of times events. 
    1820 
    1921    Event times are given in a list (sparse representation) in milliseconds. 
    2022 
    2123    >>> s1 = SpikeTrain([0.0, 0.1, 0.2, 0.5]) 
    22     >>> s2 = SpikeTrain(numpy.array([0, 1, 2, 5]), dt=0.1) 
     24    >>> s2 = SpikeTrain([0, 1, 2, 5], dt=0.1) 
    2325    >>> assert all(s1.spike_times == s2.spike_times) 
    24     >>> print s1.format(relative=True) 
    25     [ 0.   0.1  0.1  0.3] 
    2626    >>> s1.isi() 
    2727    array([ 0.1,  0.1,  0.3]) 
     
    3131    0.565685424949 
    3232    """ 
    33     # TODO in the definition, should a spike train be ordered? 
    34  
    35     # TODO : should we handle different population shapes differently? 
    3633 
    3734    ####################################################################### 
     
    5451 
    5552        if dt is not None: 
    56             if dt<=0: 
     53            if dt <= 0: 
    5754                raise ValueError("dt must be greater than zero") 
    5855            #self.spike_times *= dt 
    5956 
    60         self.dt = dt 
    61         self.t_start = t_start 
    62         self.t_stop  = t_stop 
    63  
     57        self.dt          = dt 
     58        self.t_start     = t_start 
     59        self.t_stop      = t_stop 
    6460        self.spike_times = numpy.array(spike_times, 'float') 
    6561 
     
    6763        # the spikes with t >= t_start 
    6864        if self.t_start is not None: 
    69             idx = numpy.where(self.spike_times >= self.t_start)[0] 
    70             self.spike_times = self.spike_times[idx] 
    71             assert numpy.all(self.spike_times >= self.t_start), \ 
    72               "Spike times out of range (t_start=%s, min(spike_times)=%s" % (self.t_start,self.spike_times.min()) 
     65            self.spike_times = numpy.extract((self.spike_times >= self.t_start), self.spike_times) 
    7366 
    7467        # If t_stop is not None, we resize the spike_train keeping only 
    7568        # the spikes with t <= t_stop 
    7669        if self.t_stop is not None: 
    77             idx = numpy.where(self.spike_times <= self.t_stop)[0] 
    78             self.spike_times = self.spike_times[idx] 
    79             assert numpy.all(self.spike_times <= self.t_stop), \ 
    80                 "Spike times out of range (t_stop=%s, max(spike_times)=%s" % (self.t_stop,self.spike_times.max()) 
     70            self.spike_times = numpy.extract((self.spike_times <= self.t_stop), self.spike_times) 
    8171 
    8272        # Here we deal with the t_start and t_stop values if the SpikeTrain 
     
    10999 
    110100        if self.t_start >= self.t_stop : 
    111             raise Exception("Incompatible time interval for the creation of the SpikeTrain. t_start=%s, t_stop=%s" % (self.t_start, self.t_stop)) 
     101            raise Exception("Incompatible time interval : t_start = %s, t_stop = %s" % (self.t_start, self.t_stop)) 
    112102        if self.t_start < 0: 
    113103            raise ValueError("t_start must not be negative") 
     
    120110    def __len__(self): 
    121111        return len(self.spike_times) 
     112     
     113    def __getslice(self, i, j): 
     114        return self.spike_times[i:j] 
     115     
     116    def __getdisplay__(self,display): 
     117        if display is False: 
     118            return None 
     119        elif display is True: 
     120            pylab.figure() 
     121            return pylab 
     122        else: 
     123            return display 
     124     
     125    def __labels__(self, subplot, xlabel, ylabel): 
     126        if hasattr(subplot, 'xlabel'): 
     127            subplot.xlabel(xlabel, size="large") 
     128            subplot.ylabel(ylabel, size="large") 
     129        else: 
     130            subplot.set_xlabel(xlabel, size="large") 
     131            subplot.set_ylabel(ylabel, size="large") 
    122132 
    123133    def duration(self): 
     
    133143        spike_times.sort() 
    134144 
    135         if relative and len(spike_times)>0: 
     145        if relative and len(spike_times) > 0: 
    136146            spike_times[1:] = spike_times[1:] - spike_times[:-1] 
    137147 
     
    143153        return spike_times 
    144154 
     155 
     156 
    145157    ####################################################################### 
    146158    ## Analysis methods that can be applied to a SpikeTrain object       ## 
    147159    ####################################################################### 
     160     
    148161    def isi(self): 
    149162        # TODO this needs some thinking to know how to handle the border, in particular the 
    150163        # first spike and t_start 
    151         return self.format(relative=True, quantized=False)[1:] 
     164        return numpy.diff(self.spike_times) 
    152165 
    153166    # Returns the mean firing rate of the SpikeTrain 
    154167    def mean_rate(self, t_start=None, t_stop=None): 
    155         """ Mean firing rate between t_start and t_stop in Hz 
    156  
    157         NOTE: avoided calling where when defaults settings t_start=self.t_start, t_stop=self.t_stop 
    158         """ 
    159         if (t_start==None) & (t_stop==None): 
    160             t_start=self.t_start 
    161             t_stop=self.t_stop 
    162             idx = self.spike_times 
    163         else: 
    164             if t_start==None: t_start=self.t_start 
    165             if t_stop==None: t_stop=self.t_stop 
     168        """  
     169        Mean firing rate between t_start and t_stop in Hz 
     170        By default, if t_start and t_stop are not defined, we used those of the SpikeTrain object 
     171        """ 
     172        if (t_start == None) & (t_stop == None): 
     173            t_start = self.t_start 
     174            t_stop  = self.t_stop 
     175            idx     = self.spike_times 
     176        else: 
     177            if t_start == None: t_start=self.t_start 
     178            if t_stop == None: t_stop=self.t_stop 
    166179            idx = numpy.where((self.spike_times >= t_start) & (self.spike_times <= t_stop))[0] 
    167180        return 1000.*len(idx)/(t_stop-t_start) 
     
    179192        Poisson-type behavior. As a measure for irregularity in the network one 
    180193        can use the average irregularity across all neurons. 
    181  
    182         TODO: is it useful to get the std of CV? 
     194         
     195        See also 
     196            SpikeList.cv_isi 
     197 
    183198        """ 
    184199        isi = self.isi() 
     
    190205 
    191206    def fano_factor_isi(self): 
    192         """ returns the fano factor of this spike trains ISI (see SpikeList.fano_factor)""" 
     207        """  
     208        Return the fano factor of this spike trains ISI  
     209         
     210        See also 
     211            SpikeList.fano_factor 
     212        """ 
    193213        isi = self.isi() 
    194214        if len(isi) > 0: 
    195             #firing_rate = self.time_histogram(time_bin,False)       
    196215            fano = numpy.var(isi)/numpy.mean(isi) 
    197216            return fano 
     
    205224        return numpy.arange(self.t_start, self.t_stop, time_bin) 
    206225 
    207     def raster_plot(self, t_start=None, t_stop=None, color='b'): 
     226 
     227    def raster_plot(self, t_start=None, t_stop=None, color='b', display=True): 
    208228        """ 
    209229        Generate a raster plot with the SpikeTrain in a subwindow of interest, 
     
    211231        of the SpikeTrain objects. If not defined, we use the one of the SpikeTrain 
    212232        object 
     233         
     234        - color is a string color like in pylab plots 'k','r' 
     235        - display can be either a boolean (to ensure backward compatibilities) or a figure 
     236        object with a plot() method. 
     237         
     238        See also 
     239            SpikeList.raster_plot 
    213240        """ 
    214241        if t_start is None: 
     
    216243        if t_stop is None: 
    217244            t_stop = self.t_stop 
    218         idx = numpy.where((self.spike_times >= t_start) & (self.spike_times <= t_stop))[0] 
    219         spikes = self.spike_times[idx] 
    220         if len(spikes) > 0: 
    221             pylab.figure() 
    222             pylab.scatter(spikes,numpy.ones(len(spikes)), c=color) 
    223             pylab.xlabel("Time (ms)", size="x-large") 
    224             pylab.ylabel("Neuron #", size="x-large") 
     245        spikes = numpy.extract((self.spike_times >= t_start) & (self.spike_times <= t_stop), self.spike_times) 
     246        subplot = self.__getdisplay__(display) 
     247        if not subplot: 
     248            return spikes 
     249        else: 
     250            if len(spikes) > 0: 
     251                xlabel = "Time (ms)" 
     252                ylabel = "Neurons #" 
     253                self.__labels__(subplot, xlabel, ylabel) 
     254                subplot.scatter(spikes,numpy.ones(len(spikes)), c=color) 
     255                 
     256 
    225257 
    226258    # Method to sort a SpikeTrain 
    227259    def sort_by_time(self): 
     260        """ 
     261        Sort the spike times according to the time axis 
     262        """ 
    228263        self.spike_time = sort(self.spike_time) 
    229264 
    230265    def subSpikeTrain(self, t_start, t_stop): 
    231         """ Returns a spike train sliced between t_start and t_stop 
     266        """  
     267        Return a spike train sliced between t_start and t_stop 
    232268        t_start and t_stop may either be single values or sequences of start 
    233269        and stop times. 
     
    269305        if self.t_start != 0: 
    270306            self.spike_times -= self.t_start 
    271             self.t_stop -= self.t_start 
    272             self.t_start = 0.0 
     307            self.t_stop      -= self.t_start 
     308            self.t_start      = 0.0 
    273309 
    274310    def tuning_curve(self, var_array, normalized=False, method='sum'): 
     
    358394         
    359395 
     396 
    360397class SpikeList(object): 
    361398    """ 
     
    363400 
    364401    >>> sl = SpikeList(3, [(0, 0.1), (1, 0.1), (0, 0.2)]) 
    365     >>> type( sl.spiketrains[0] ) 
     402    >>> type( sl[0] ) 
    366403    <type SpikeTrain> 
    367     >>> sl.spiketrains[0].spike_times 
     404    >>> sl[0].spike_times 
    368405    array([ 0.1,  0.2]) 
    369406    >>> sl.as_ids_times() 
     
    390427        `spikes` is a list/tuple of (id,t) tuples (id in id_list) 
    391428        `id_list` is the list of ids of all recorded cells (needed for silent cells) 
    392         If `dt`is specified, the time values should be ints, 
     429        If `dt` is specified, the time values should be ints, 
    393430        and will be multiplied by `dt`, otherwise time values should be floats. 
    394431        If `t_start` and `t_stop` are not specified, they are inferred from the data. 
    395432 
    396         dt, t_start and t_stop are shared for all SpikeTrains 
     433        dt, t_start and t_stop are shared for all SpikeTrains in the SpikeList 
    397434 
    398435        """ 
    399436        self.id_list = id_list 
    400437        self.t_start = t_start 
    401         self.t_stop = t_stop 
    402         self.dt = dt 
    403         self.label = label 
    404         # transform spikes in a spike array 
     438        self.t_stop  = t_stop 
     439        self.dt      = dt 
     440        self.label   = label 
    405441        self.spiketrains = {} 
     442 
    406443        for id in id_list: 
    407444            self.spiketrains[id] = [] 
    408445        for id,time in spikes: 
    409             if id in id_list: #id_list can be a subset of the list of recorded neurons 
     446            if id in self.id_list: #id_list can be a subset of the list of recorded neurons 
    410447                self.spiketrains[id].append(time) 
    411448 
    412449        # writing as a list of SpikeTrains 
    413450        for id,spikes in self.spiketrains.items(): # 
    414             self.spiketrains[id] = SpikeTrain(spike_times=spikes, dt=self.dt, t_start=self.t_start, t_stop=self.t_stop) 
     451            self.spiketrains[id] = SpikeTrain(spikes, self.dt, self.t_start, self.t_stop) 
     452         
    415453        if len(self) > 0 and (self.t_start is None or self.t_stop is None): 
    416454            self.__calc_startstop() 
     455 
    417456 
    418457    def N(self): 
     
    428467            if self.t_start is None: 
    429468                start_times = numpy.array([self.spiketrains[idx].t_start for idx in self.id_list]) 
    430                 self.t_start = start_times.min() 
     469                self.t_start = numpy.min(start_times) 
     470                print "Warning, t_start is infered from the data : %f" %self.t_start 
    431471                for id in self.spiketrains.keys(): 
    432472                    self.spiketrains[id].t_start = self.t_start 
    433473            if self.t_stop is None: 
    434474                stop_times = numpy.array([self.spiketrains[idx].t_stop for idx in self.id_list]) 
    435                 self.t_stop  = stop_times.max() 
     475                self.t_stop  = numpy.max(stop_times) 
     476                print "Warning, t_stop  is infered from the data : %f" %self.t_stop 
    436477                for id in self.spiketrains.keys(): 
    437478                    self.spiketrains[id].t_stop = self.t_stop 
     
    439480            raise Exception("No SpikeTrains") 
    440481 
    441     def __getitem__(self, i): 
    442         return self.spiketrains[i] 
    443  
    444     def __setitem__(self, i, val): 
    445         assert isinstance(val, SpikeTrain), "A SpikeList object can only contain SpikeTrain objects" 
    446         self.spiketrains[i] = val 
    447         self.id_list.append(i) 
     482    def __getitem__(self, id): 
     483        return self.spiketrains[id] 
     484     
     485    #def __getslice__(self, i, j): 
     486    #def __setslice__(self, i, j): 
     487 
     488    def __setitem__(self, id, spktrain): 
     489        assert isinstance(spktrain, SpikeTrain), "A SpikeList object can only contain SpikeTrain objects" 
     490        self.spiketrains[id] = spktrain 
     491        if not id in self.id_list: 
     492            self.id_list.append(id) 
    448493        self.__calc_startstop() 
    449494 
     
    454499        return len(self.spiketrains) 
    455500 
    456     def append(self, id, spiketrain): 
     501    def __getdisplay__(self,display): 
     502        if display is False: 
     503            return None 
     504        elif display is True: 
     505            pylab.figure() 
     506            return pylab 
     507        else: 
     508            return display 
     509 
     510    def __labels__(self, subplot, xlabel, ylabel): 
     511        if hasattr(subplot, 'xlabel'): 
     512            subplot.xlabel(xlabel, size="large") 
     513            subplot.ylabel(ylabel, size="large") 
     514        else: 
     515            subplot.set_xlabel(xlabel, size="large") 
     516            subplot.set_ylabel(ylabel, size="large") 
     517 
     518    def append(self, id, spktrain): 
    457519        """ 
    458520        Add a SpikeTrain object to the SpikeList 
    459         """ 
    460         assert isinstance(spiketrain, SpikeTrain), "A SpikeList object can only contain SpikeTrain objects" 
     521         
     522        See also 
     523            concatenate 
     524        """ 
     525        assert isinstance(spktrain, SpikeTrain), "A SpikeList object can only contain SpikeTrain objects" 
    461526        if id in self.id_list: 
    462             raise Exception("Id already present in SpikeList.Use setitem instead()"
    463         else: 
    464             self.spiketrains[id] = spiketrain 
     527            raise Exception("id %d already present in SpikeList. Use setitem instead()" %id
     528        else: 
     529            self.spiketrains[id] = spktrain 
    465530            self.id_list.append(id) 
    466         self.t_start = min(self.t_start, spiketrain.t_start) or spiketrain.t_start # in case self.t_start is None 
    467         self.t_stop = max(self.t_stop, spiketrain.t_stop) 
     531        self.t_start = min(self.t_start, spktrain.t_start) or spktrain.t_start # in case self.t_start is None 
     532        self.t_stop = max(self.t_stop, spktrain.t_stop) 
    468533 
    469534    def get_time_parameters(self): 
    470535        """ 
    471         Returns the time parameters of the SpikeList (t_start, t_stop, dt) 
     536        Return the time parameters of the SpikeList (t_start, t_stop, dt) 
    472537        """ 
    473538        return (self.t_start, self.t_stop, self.dt) 
    474539 
    475     # Same as for the SpikeTrain object 
    476540    def time_axis(self, time_bin): 
     541        """ 
     542        Return a time_axis between t_start and t_stop with bin of size time_bin 
     543        """ 
    477544        return numpy.arange(self.t_start, self.t_stop, time_bin) 
    478545 
    479     def concatenate(self, SpikeList_list): 
    480         """ 
    481         Concatenation of a list of SpikeLists to the current SpikeList 
    482         """ 
     546    def concatenate(self, spklists): 
     547        """ 
     548        Concatenation of a SpikeLists to the current SpikeList. 
     549        SpikeLists could be a single SpikeList or a list of SpikeLists 
     550         
     551        See also 
     552            append 
     553        """ 
     554        if isinstance(spklists, SpikeList): 
     555            spklists = [SpikeLists] 
    483556        # We check that Spike Lists have similar time_axis 
    484         sl_= SpikeList_list[0] 
    485         for sl in SpikeList_list: 
     557        for sl in spklists: 
    486558            if not sl.get_time_parameters() == self.get_time_parameters(): 
    487559                raise Exception("Spike Lists should have similar time_axis") 
    488         for sl in SpikeList_list
     560        for sl in spklists
    489561            for id in sl.id_list: 
    490562                self.append(id, sl.spiketrains[id]) 
     
    508580    def idsubSpikeList(self, id_list): 
    509581        """ 
    510         Generate a new SpikeList truncated from a particular sublist of cells 
    511         """ 
    512         # We check what are the elements that are in self.id_list and not in 
    513         # id_list. We remove such elements from the SpikeList 
     582        Generate a new SpikeList truncated from a particular sublist of cells. The 
     583        new sub SpikeList keeps the time parameters of the old one (dt, t_start, t_stop) 
     584         
     585        id_list could be an integer, and N cells will be randomly selected, or 
     586        a sublist of the ids. 
     587         
     588        See also 
     589            timesubSpikeList 
     590        """ 
    514591        new_SpkList = SpikeList([], [], self.dt, self.t_start, self.t_stop) 
    515         if isinstance(id_list,int): 
     592        if isinstance(id_list, int): 
    516593            id_list = numpy.random.permutation(self.id_list)[0:id_list] 
    517594        for id in id_list: 
     
    519596                new_SpkList.append(id, self.spiketrains[id]) 
    520597            except Exception: 
    521                 print "Item %s is not in the source SpikeList" %id 
     598                print "id %d is not in the source SpikeList" %id 
    522599        return new_SpkList 
    523600 
     
    525602        """ 
    526603        Generate a new SpikeList truncated from particular time boundaries 
    527         returns a new SpikeList 
    528  
     604         
     605        See also 
     606            idsubSpikeList 
    529607        """ 
    530608        new_SpkList = SpikeList([], [], self.dt, t_start, t_stop) 
     
    540618 
    541619     
    542     def isi(self, nbins=100, display=False): 
     620    def isi(self, nbins=100): 
    543621        """ 
    544622        Return the list of all the isi vectors for all the SpikeTrains objects 
    545         within the SpikeList. If display is True, then it plots the distribution 
    546         of the ISI 
     623        within the SpikeList. 
     624         
     625        See also: 
     626            isi_hist 
    547627        """ 
    548628        isis = [] 
    549         for idx,id in enumerate(self.id_list)
     629        for id in self.id_list
    550630            isis.append(self.spiketrains[id].isi()) 
    551         if not display: 
    552             return isis 
    553         else: 
    554             ISI = numpy.array([]) 
    555             for idx in xrange(self.N()): 
    556                 ISI = numpy.concatenate((ISI,isis[idx])) 
    557             values, xaxis = numpy.histogram(ISI, nbins, normed=1) 
    558             pylab.figure() 
    559             pylab.plot(xaxis, values) 
    560             pylab.xlabel("Inter Spike Interval (ms)", size="x-large") 
    561             pylab.ylabel("% of Neurons", size="x-large") 
    562  
    563     def isi_hist(self, bins): 
    564         """Return the histogram of the ISI. 
    565          
     631        return isis 
     632 
     633 
     634    def isi_hist(self, bins=50, display=False, kwargs={}): 
     635        """ 
     636        Return the histogram of the ISI. 
    566637        bins may either be an integer, giving the number of bins (between the 
    567638        min and max of the data) or a list/array containing the lower edges of 
    568639        the bins. 
    569          
    570         Returns a tuple, (histogram values, bin edges) 
    571         """ 
    572         isis = numpy.concatenate(self.isi()) 
    573         return numpy.histogram(isis, bins=bins) 
     640        If display is True or is a plot object, then the histogram is plotted 
     641        Otherwise, the function returns a tuple, (histogram values, bin edges) 
     642         
     643        See also: 
     644            isi 
     645        """ 
     646        isis          = numpy.concatenate(self.isi()) 
     647        values, xaxis = numpy.histogram(isis, bins=bins, normed=True) 
     648        subplot       = self.__getdisplay__(display) 
     649        if not subplot: 
     650            return values, xaxis 
     651        else: 
     652            xlabel = "Inter Spike Interval (ms)" 
     653            ylabel = "% of Neurons" 
     654            self.__labels__(subplot, xlabel, ylabel) 
     655            subplot.plot(xaxis, values, **kwargs) 
     656            pylab.draw() 
    574657 
    575658     
    576     def cv_isi(self, nbins=100, display=False): 
     659    def cv_isi(self, nbins=100): 
    577660        """ 
    578661        Return the list of all the cv coefficients for all the SpikeTrains objects 
    579         within the SpikeList. If display is True, then it plots the distribution 
    580         of the CVs 
     662        within the SpikeList. 
     663         
     664        See also: 
     665            cv_isi_hist 
    581666        """ 
    582667        cvs_isi = [] 
     
    585670            if len(isi) > 1: 
    586671                cvs_isi.append(numpy.std(isi)/numpy.mean(isi)) 
    587         if not display: 
    588             return cvs_isi 
    589         else: 
    590             CV = numpy.array([]) 
    591             for idx in xrange(len(cvs_isi)): 
    592                 CV = numpy.concatenate((CV,[cvs_isi[idx]])) 
    593             values, xaxis = numpy.histogram(CV, nbins, normed=1) 
    594             pylab.figure() 
    595             pylab.plot(xaxis, values) 
    596             pylab.xlabel("Inter Spike Interval CV", size="x-large") 
    597             pylab.ylabel("% of Neurons", size="x-large") 
    598  
    599     def cv_isi_hist(self, bins): 
    600         cvs = numpy.array(self.cv_isi()) 
    601         return numpy.histogram(cvs, bins=bins) 
     672        return cvs_isi 
     673 
     674 
     675    def cv_isi_hist(self, bins=50, display=False, kwargs={}): 
     676        """ 
     677        Return the histogram of the cv coefficients. 
     678        bins may either be an integer, giving the number of bins (between the 
     679        min and max of the data) or a list/array containing the lower edges of 
     680        the bins. 
     681        If display is True or is a plot object, then the histogram is plotted 
     682        Otherwise, the function returns a tuple, (histogram values, bin edges) 
     683         
     684        See also: 
     685            cv_isi 
     686        """ 
     687        cvs           = numpy.array(self.cv_isi()) 
     688        values, xaxis = numpy.histogram(cvs, bins=bins, normed=True) 
     689        subplot       = self.__getdisplay__(display) 
     690        if not subplot: 
     691            return values, xaxis 
     692        else: 
     693            xlabel = "Inter Spike Interval (ms)" 
     694            ylabel = "% of Neurons" 
     695            self.__labels__(subplot, xlabel, ylabel) 
     696            subplot.plot(xaxis, values, **kwargs) 
     697            pylab.draw() 
     698             
    602699 
    603700    def time_axis(self, time_bin): 
    604701        return numpy.arange(self.t_start, self.t_stop, time_bin) 
    605702 
     703 
    606704    def mean_rate(self, t_start=None, t_stop=None): 
    607705        """ 
    608         Return the mean firing rate averaged accross all SpikeTrains 
    609  
    610         see mean_rates 
     706        Return the mean firing rate averaged accross all SpikeTrains between t_start and t_stop. 
     707 
     708        See also 
     709            mean_rates 
    611710        """ 
    612711        return numpy.mean(self.mean_rates(t_start, t_stop)) 
    613      
     712 
     713 
    614714    def mean_rate_std(self, t_start=None, t_stop=None): 
    615715        """ 
    616         Std deviation of the Mean firing rate averaged accross all SpikeTrains 
    617  
    618         see mean_rate 
     716        Std deviation of the Mean firing rate averaged accross all SpikeTrains between t_start and t_stop 
     717 
     718        See also 
     719            mean_rate 
    619720        """ 
    620721        return numpy.std(self.mean_rates(t_start, t_stop)) 
    621      
     722 
     723 
    622724    def mean_rates(self, t_start=None, t_stop=None): 
    623         """ returns a vector of the size of id_list giving the mean rate for each neuron 
    624  
    625         see SpikeTrain.mean_rate 
     725        """  
     726        Returns a vector of the size of id_list giving the mean rate for each neuron 
     727 
     728        See also 
     729            SpikeTrain.mean_rate 
    626730        """ 
    627731        rates = [] 
    628732        for id in self.id_list: 
    629733            rates.append(self.spiketrains[id].mean_rate(t_start, t_stop)) 
    630  
    631734        return rates 
    632735     
    633     def rate_distribution(self, nbins=25, normalize=True, display=False): 
     736     
     737    def rate_distribution(self, nbins=25, normalize=True, display=False, kwargs={}): 
    634738        """ 
    635739        Return a vector with all the mean firing rates for all SpikeTrains. 
    636740        If display is True, then it plots the distribution of the rates 
    637741        """ 
    638         #rates = numpy.zeros(self.N(), float) 
    639         #for idx,id in enumerate(self.id_list): 
    640         #    rates[idx] = self.spiketrains[id].mean_firing_rate() 
    641         rates = self.mean_rates() 
    642         if not display: 
     742        rates   = self.mean_rates() 
     743        subplot = self.__getdisplay__(display) 
     744        if not subplot: 
    643745            return rates 
    644746        else: 
    645             values, xaxis = numpy.histogram(rates, nbins, normed=1) 
    646             pylab.plot(xaxis,values) 
    647             pylab.ylabel("% of Neurons", size="x-large") 
    648             pylab.xlabel("Average Firing Rate (Hz)", size="x-large") 
    649  
    650     def spike_histogram(self, time_bin, normalized=False, display=False): 
     747            values, xaxis = numpy.histogram(rates, nbins, normed=True) 
     748            xlabel = "Average Firing Rate (Hz)" 
     749            ylabel = "% of Neurons" 
     750            self.__labels__(subplot, xlabel, ylabel) 
     751            subplot.plot(xaxis, values, **kwargs) 
     752            pylab.draw() 
     753             
     754 
     755 
     756    def spike_histogram(self, time_bin, normalized=False, display=False, kwargs={}): 
    651757        """ 
    652758        Generate an array with all the spike_histograms of all the SpikeTrains 
    653759        objects within the SpikeList. If display is True, then it plots the 
    654760        mean firing rate of the all population along time 
     761         
     762        See also 
     763            firing_rate 
    655764        """ 
    656765        if hasattr(time_bin, '__len__'): 
     
    660769        spike_hist = numpy.zeros((self.N(), nbins), float) 
    661770        logging.debug("nbins = %d" % nbins) 
     771        subplot = self.__getdisplay__(display) 
    662772        for idx,id in enumerate(self.id_list): 
    663773            try: 
     
    668778                print self.spiketrains[id].t_start, self.spiketrains[id].t_stop 
    669779                raise 
    670         if display: 
    671             pylab.figure() 
    672             pylab.plot(self.time_axis(time_bin),sum(spike_hist)/self.N()) 
    673             pylab.ylabel("Mean Number of Spikes per bin", size="x-large") 
    674             pylab.xlabel("Time (ms)", size="x-large") 
    675         return spike_hist 
    676  
    677     def firing_rate(self, time_bin, display=False): 
     780        if not subplot: 
     781            return spike_hist 
     782        else: 
     783            ylabel = "Spikes per bin" 
     784            xlabel = "Time (ms)" 
     785            self.__labels__(subplot, xlabel, ylabel) 
     786            subplot.plot(self.time_axis(time_bin),sum(spike_hist)/self.N(),**kwargs) 
     787            pylab.draw() 
     788             
     789 
     790    def firing_rate(self, time_bin, display=False, kwargs={}): 
    678791        """ 
    679792        Calculate firing rate traces (in Hz) from arrays of spike times. 
     
    684797        >>> pylab.plot(output[0].time_axis(dt),sum(output.firing_rate(dt))) 
    685798        """ 
    686         return self.spike_histogram(time_bin, normalized=True, display=display) 
     799        return self.spike_histogram(time_bin, normalized=True, display=display, kwargs=kwargs) 
     800 
    687801 
    688802    # Compute the Fano Factor of the population. Need to be checked 
     
    692806        fano = numpy.var(firing_rate)/numpy.mean(firing_rate) 
    693807        return fano 
    694      
     808 
     809 
    695810    def fano_factors_isi(self): 
    696         """ returns a list containing the fano factors for each neuron""" 
     811        """  
     812        Return a list containing the fano factors for each neuron 
     813         
     814        See also 
     815            isi, isi_cv 
     816        """ 
    697817        fano_factors = [] 
    698818        for id in self.id_list: 
     
    720840 
    721841     
    722     def activity_map(self, dims, bounds=None, display=False): 
    723         """ 
    724         Generate a map of the activity during t_start and t_stop. 
    725         If dims is a tuple, then cells are placed on a grid of size 
     842    def activity_map(self, dims, t_start=None, t_stop=None, bounds=None, display=False): 
     843        """ 
     844        Generate a 2D map of the activity averaged between t_start and t_stop. 
     845        If t_start and t_stop are not defined, we used those of the SpikeList object 
     846 
     847        if dims is a tuple, then cells are placed on a grid of size 
    726848        (N,M), else if dims is an array of size (2,nb_cells) with the 
    727849        x (first line) and y (second line) flotting positions of the cells, 
    728         we generate a scatter plot. bounds is a parameters allowing to specify 
     850        we generate a scatter plot.  
     851         
     852        bounds is a parameters allowing to specify 
    729853        the range of the colorbar 
    730         """ 
     854         
     855        See also 
     856            activity_movie 
     857        """ 
     858        subplot = self.__getdisplay__(display) 
    731859        if isinstance(dims, tuple) or isinstance(dims, list): 
    732860            activity_map = numpy.zeros(dims,float) 
    733             rates = self.mean_rates() 
     861            rates        = self.mean_rates() 
    734862            for id in self.id_list: 
    735863                position = self.id2position(id, dims) 
    736864                activity_map[position] = rates[id] 
    737             if not display
     865            if not subplot
    738866                return activity_map 
    739867            else: 
    740                 pylab.figure() 
    741                 pylab.imshow(activity_map,interpolation='bicubic') 
    742                 pylab.colorbar() 
    743                 pylab.show() 
     868                subplot.imshow(activity_map, interpolation='bicubic') 
     869                subplot.colorbar() 
     870                subplot.show() 
    744871                if bounds is not None: 
    745                     pylab.clim(bounds) 
     872                    subplot.clim(bounds) 
    746873        elif isinstance(dims, numpy.ndarray): 
    747874            if not len(self.id_list) == len(dims[0]): 
    748875                raise Exception("Error, the number of positions does not match the number of cells in the SpikeList") 
    749876            rates = self.mean_rates() 
    750             #for id in self.id_list: 
    751             #    rates.append(self.spiketrains[id].mean_firing_rate()) 
    752             if not display: 
     877            if not subplot: 
    753878                return rates 
    754879            else: 
    755880                x = dims[0,:] 
    756881                y = dims[1,:] 
    757                 pylab.scatter(x,y,c=rates) 
    758                 pylab.colorbar() 
     882                subplot.scatter(x,y,c=rates) 
     883                subplot.colorbar() 
    759884                if bounds is not None: 
    760                     pylab.clim(bounds) 
     885                    subplot.clim(bounds) 
    761886 
    762887    def pairwise_correlations(self, pairs, time_bin=1., display=False): 
     
    774899        hist_1 = spk1.spike_histogram(time_bin) 
    775900        hist_2 = spk2.spike_histogram(time_bin) 
    776         print spk1, spk2 
    777901        length = 2*len(hist_1[0]) 
     902        subplot = self.__getdisplay__(display) 
    778903        results = numpy.zeros((nb_pairs,length), float) 
    779904        for idx in xrange(nb_pairs): 
     
    782907            if sum(hist_1[idx]) > 0 and sum(hist_2[idx] > 0): 
    783908                results[idx,:] = ccf(hist_1[idx],hist_2[idx]) 
    784         if (display)
     909        if subplot
    785910            results = sum(results)/nb_pairs 
    786911            pylab.figure() 
    787912            xaxis  = time_bin*numpy.arange(-len(results