Changeset 99

Show
Ignore:
Timestamp:
01/15/08 15:47:03 (1 year ago)
Author:
apdavison
Message:

Added more unit tests for the SpikeTrain class

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/spikes.py

    r98 r99  
    5555                raise ValueError("dt must be greater than zero") 
    5656            #self.spike_times *= dt 
    57 ##        else: 
    58 ##            # knowing dt may be useful (for the spike matrix) 
    59 ##            dt = 1e4 # 0.1 ms per default 
    60 ##            # TODO : compute as the best descriptor for a list of floats 
    6157 
    6258        self.dt = dt 
     
    10197            if self.t_start is None: 
    10298                self.t_start = numpy.min(self.spike_times) 
     99            if numpy.any(self.spike_times < self.t_start): 
     100                raise ValueError("Spike times must not be less than t_start") 
    103101            if self.t_stop is None: 
    104102                self.t_stop = numpy.max(self.spike_times) 
    105         # TODO raise an error if some data is outside [t_start, t_stop] ? 
    106         #TODO return an exception if self.t_stop < self.t_start (when not empty) 
     103            if numpy.any(self.spike_times > self.t_stop): 
     104                raise ValueError("Spike times must not be greater than t_stop") 
     105         
    107106        if self.t_start >= self.t_stop : 
    108107            raise Exception("Incompatible time interval for the creation of the SpikeTrain") 
    109108        if self.t_start < 0: 
    110109            raise ValueError("t_start must not be negative") 
     110        if numpy.any(self.spike_times < 0): 
     111            raise ValueError("Spike times must not be negative") 
    111112 
    112113    def __str__(self): 
     
    118119    def format(self, relative=False, quantized=False): 
    119120        """ 
    120         a function to format a spike train from a format to another. 
    121         outputs a list. 
     121        a function to format a spike train from one format to another. 
     122        outputs a numpy array 
    122123 
    123124        """ 
    124125        spike_times = self.spike_times.copy() 
     126        spike_times.sort() 
    125127 
    126128        if relative and len(spike_times)>0: 
    127             spike_times[1:] = self.spike_times[1:] - self.spike_times[:-1] 
     129            spike_times[1:] = spike_times[1:] - spike_times[:-1] 
    128130 
    129131        if quantized: 
    130             spike_times =  numpy.array([time/self.dt for time in spike_times],int) 
     132            assert quantized > 0, "quantized must either be False or a positive number" 
     133            #spike_times =  numpy.array([time/self.quantized for time in spike_times],int) 
     134            spike_times = (spike_times/quantized).round().astype('int') 
    131135 
    132136        return spike_times 
     
    138142        # TODO this needs some thinking to know how to handle the border, in particular the 
    139143        # first spike and t_start 
    140         return self.format(relative=True)[1:] 
     144        return self.format(relative=True, quantized=False)[1:] 
    141145 
    142146    # Returns the mean firing rate of the SpikeTrain 
  • trunk/test/test_base_classes.py

    r87 r99  
    88 
    99def arrays_are_equal(a, b): 
     10    a.sort() 
     11    b.sort() 
    1012    return (a == b).all() 
    1113 
    1214def arrays_almost_equal(a,b, threshold): 
     15    a.sort() 
     16    b.sort() 
    1317    diff = numpy.abs(a - b) 
    1418    return ((threshold - diff) > 0).all() 
     
    7478        self.assertRaises(ValueError, SpikeTrain, [0, 10, 30, 60, 15], dt=0.01, t_stop=0.5) 
    7579     
    76     #def testInitWithTStartGreaterThanMinSpikeTime(self): raise Exception() 
     80    def testInitWithTStartGreaterThanMinSpikeTime(self): 
     81        self.assertRaises(ValueError, SpikeTrain, [0, 10, 30, 60, 15], dt=0.01, t_start=0.1) 
    7782         
     83    def testFormatRelative(self): 
     84        self.assert_( arrays_almost_equal(self.s1.format(relative=True, quantized=False), 
     85                                          numpy.array([0.0, 0.1, 0.05, 0.15, 0.3]), 
     86                                          1e-12)) 
     87     
     88    def testFormatQuantized(self): 
     89        self.assert_( arrays_almost_equal(self.s1.format(relative=False, quantized=0.01), 
     90                                          numpy.array([0, 10, 30, 60, 15]), 
     91                                          1e-12)) 
     92        self.assert_( arrays_almost_equal(self.s1.format(relative=False, quantized=0.05), 
     93                                          numpy.array([0, 2, 6, 12, 3]), 
     94                                          1e-12)) 
     95        self.assert_( arrays_almost_equal(self.s1.format(relative=False, quantized=0.1), 
     96                                          numpy.array([0, 1, 3, 6, 1]), 
     97                                          1+1e-12)) 
     98     
     99    def testFormatStandard(self): 
     100        self.assert_( arrays_are_equal(self.s1.spike_times, 
     101                                          self.s1.format(relative=False, quantized=False)) ) 
     102         
     103    def testMeanFiringRate(self): 
     104        self.assertAlmostEqual(self.s1.mean_firing_rate(), 8333.3, places=1) 
     105     
     106    def testCVISI(self): 
     107        # note I did not check this value by hand 
     108        self.assertAlmostEqual(self.s1.cv_isi(), 0.6236, places=4) 
     109         
     110    def testTimeAxis(self): 
     111        self.assert_( arrays_are_equal(self.s1.time_axis(0.1), 
     112                                       numpy.arange(0.0,0.6,0.1)) ) 
     113         
     114    def testSlice(self): 
     115        s2 = SpikeTrain([0.15, 0.3]) 
     116        s3 = SpikeTrain([0.1, 0.15, 0.3]) 
     117        s4 = self.s1.subSpikeTrain(0.11,0.4) # should not include 0.1 
     118        s5 = self.s1.subSpikeTrain(0.10,0.4) # should include 0.1 
     119        self.assert_( arrays_are_equal(s2.spike_times, 
     120                                       s4.spike_times) ) 
     121        self.assert_( arrays_are_equal(s3.spike_times, 
     122                                       s5.spike_times) ) 
     123     
    78124         
    79125if __name__ == "__main__":