Changeset 180

Show
Ignore:
Timestamp:
07/23/08 19:01:56 (4 months ago)
Author:
apdavison
Message:

Minor improvements to spikes.py and analysis.py

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/src/analysis.py

    r151 r180  
    8585    os.remove(cfilename) 
    8686 
     87def _dict_max(D): 
     88    """ 
     89    For a dict containing numerical values, contain the key for the 
     90    highest value. If there is more than one item with the same highest 
     91    value, return one of them (arbitrary - depends on the order produced 
     92    by the iterator). 
     93    """ 
     94    max_val = max(D.values()) 
     95    for k in D: 
     96        if D[k] == max_val: 
     97            return k 
    8798 
    8899class TuningCurve(object): 
     
    122133            stderr[k] = arr.std()*n/(n-1)/numpy.sqrt(n) 
    123134        return mean, stderr 
     135 
     136    def max(self): 
     137        """Return the key of the max value and the max value.""" 
     138        k = _dict_max(self._tuning_curves) 
     139        return k, self._tuning_curves[k] 
  • trunk/src/spikes.py

    r179 r180  
    317317    ## Constructor and key methods to manipulate the SpikeList objects   ## 
    318318    ####################################################################### 
    319     def __init__(self, spikes, id_list, dt=None, t_start=None, t_stop=None): 
     319    def __init__(self, spikes, id_list, dt=None, t_start=None, t_stop=None, label=''): 
    320320        """ 
    321321        `spikes` is a list/tuple of (id,t) tuples (id in id_list) 
     
    329329        """ 
    330330        self.id_list = id_list 
    331         self.N = len(id_list) 
    332331        self.t_start = t_start 
    333332        self.t_stop = t_stop 
    334333        self.dt = dt 
     334        self.label = label 
    335335        # transform spikes in a spike array 
    336336        self.spiketrains = {} 
     
    346346        if len(self) > 0 and (self.t_start is None or self.t_stop is None): 
    347347            self.__calc_startstop() 
     348 
     349    def N(self): 
     350        return len(self.id_list) 
    348351 
    349352    def __calc_startstop(self): 
     
    451454        else: 
    452455            ISI = numpy.array([]) 
    453             for idx in xrange(self.N): 
     456            for idx in xrange(self.N()): 
    454457                ISI = numpy.concatenate((ISI,isis[idx])) 
    455458            values, xaxis = numpy.histogram(ISI, nbins, normed=1) 
     
    533536    # If display is True, then it plots the distribution of the rates 
    534537    def rate_distribution(self, nbins=25, normalize=True, display=False): 
    535         #rates = numpy.zeros(self.N, float) 
     538        #rates = numpy.zeros(self.N(), float) 
    536539        #for idx,id in enumerate(self.id_list): 
    537540        #    rates[idx] = self.spiketrains[id].mean_firing_rate() 
     
    550553    def spike_histogram(self, time_bin, normalized=False, display=False): 
    551554        nbins = numpy.ceil((self.t_stop-self.t_start)/time_bin) 
    552         firing_rate = numpy.zeros((self.N,nbins), float) 
    553  
     555        firing_rate = numpy.zeros((self.N(),nbins), float) 
     556        print "nbins = %d" % nbins 
    554557        for idx,id in enumerate(self.id_list): 
     558            print idx, id 
    555559            firing_rate[idx,:] = self.spiketrains[id].time_histogram(time_bin,normalized) 
    556560        if not display: 
     
    558562        else: 
    559563            pylab.figure() 
    560             pylab.plot(self.time_axis(time_bin),sum(firing_rate)/self.N
     564            pylab.plot(self.time_axis(time_bin),sum(firing_rate)/self.N()
    561565            pylab.ylabel("Mean Number of Spikes per bin", size="x-large") 
    562566            pylab.xlabel("Time (ms)", size="x-large") 
     
    677681    def mean_rate_variance(self, time_bin): 
    678682        firing_rate = self.firing_rate(time_bin) 
    679         return numpy.var(sum(firing_rate)/self.N
     683        return numpy.var(sum(firing_rate)/self.N()
    680684 
    681685    # Function to extract the covariance of the firing rate along time, 
     
    688692            raise Exception("Error, both SpikeLists should share common t_start, t_stop and dt") 
    689693        frate_1 = self.firing_rate(time_bin) 
    690         frate_1 = sum(frate_1)/self.N 
     694        frate_1 = sum(frate_1)/self.N() 
    691695        frate_2 = SpkList.firing_rate(time_bin) 
    692         frate_2 = sum(frate_2)/SpkList.N 
     696        frate_2 = sum(frate_2)/SpkList.N() 
    693697        N = len(frate_1) 
    694698        cov = sum(frate_1*frate_2)/N-sum(frate_1)*sum(frate_2)/(N*N) 
    695699        return cov 
    696700 
    697     # Function to generate a raster plot of a certain number of cells in the 
    698     # SpikeList object. If id_list is an integer, then N ids will be randomly choosen 
    699     # in id_list. If this is a list, those id will be selected. 
    700     # Raster is made between t_start and t_stop (region of interest, not the global ones) 
    701     # and colors can be a list of color (each for one cells) or a single string 
    702     # to apply the same color to all the raster plots 
     701     
    703702    def raster_plot(self, id_list=None, t_start=None, t_stop=None, colors='b', subplot=None, size=1): 
     703        """ 
     704        Generate a raster plot of a certain number of cells in the 
     705        SpikeList object. If id_list is an integer, then N ids will be randomly choosen 
     706        in id_list. If this is a list, those id will be selected. 
     707        Raster is made between t_start and t_stop (region of interest, not the global ones) 
     708        and colors can be a list of color (each for one cells) or a single string 
     709        to apply the same color to all the raster plots. 
     710        """ 
    704711        if id_list == None:  
    705712            id_list = self.id_list 
     
    723730            ids = ids[idx] 
    724731            if len(spike_times) > 0: 
     732                print "Plotting %d points for %s" % (len(spike_times), self.label) 
    725733                subplot.scatter(spike_times, ids, s=size, c=colors) 
    726734        elif len(colors) != len(id_list): 
     
    874882        """ 
    875883 
    876         spike_array = list([ [] for i in range(self.N)]) 
     884        spike_array = list([ [] for i in range(self.N())]) 
    877885        for spike in spikes: 
    878886            spike_array[spike[0]].append(spike[1]) 
     
    954962        id_list = range(id_list) 
    955963    spikes = readFile(filename) 
    956     return SpikeList(spikes, id_list, dt, t_start, t_stop) 
     964    print "Read %d spikes from %s" % (len(spikes), filename) 
     965    return SpikeList(spikes, id_list, dt, t_start, t_stop, label=filename) 
    957966 
    958967