Changeset 307

Show
Ignore:
Timestamp:
11/08/08 10:58:36 (2 months ago)
Author:
emuller
Message:

Fixed inh_gamma_generator hazard function numerics
Fixed poisson_process ISI buffer overrun
Added stgen module docstring
Added more tests
Added inh_gamma_psth.py example

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/src/__init__.py

    r304 r307  
    66 
    77# The nice thing would be to gathered every non standard 
    8 # dependency here, in order to centralizz the warning 
     8# dependency here, in order to centralize the warning 
    99# messages and the check 
    1010dependencies = {'pylab' : {'website' : 'http://matplotlib.sourceforge.net/', 'is_present' : False, 'check':False}, 
     
    1616                'NeuroTools.facets.hdf5' : {'website' : None, 'is_present' : False, 'check':False}, 
    1717                'srblib' : {'website' : 'http://www.sdsc.edu/srb/index.php/Python', 'is_present' : False, 'check':False}, 
     18                'rpy' : {'website' : 'http://rpy.sourceforge.net/', 'is_present' : False, 'check':False}, 
     19 
    1820                ## Add here your extensions ### 
    1921               } 
  • trunk/src/stgen.py

    r305 r307  
     1""" 
     2NeuroTools.stgen 
     3================ 
     4 
     5A collection of tools for stochastic process generation. 
     6 
     7 
     8Classes 
     9------- 
     10 
     11StGen - Object to generate stochastic processes of various kinds 
     12        and return them as SpikeTrain or AnalogSignal objects. 
     13 
     14 
     15Functions 
     16--------- 
     17 
     18shotnoise_fromspikes - Convolves the provided spike train with shot decaying exponential. 
     19 
     20gamma_hazard - Compute the hazard function for a gamma process with parameters a,b. 
     21""" 
     22 
     23 
    124# TODO : needs refactoring 
    225# TODO: needs python gsl wrapper / convert to numpy 
     
    932 
    1033 
    11 def gamma_hazard(x, a, b, dt=1e-4): 
     34def gamma_hazard_scipy(x, a, b, dt=1e-4): 
    1235    """ 
    1336    Compute the hazard function for a gamma process with parameters a,b 
     
    2144    """ 
    2245 
    23     # TODO: this algorithm has numerical problems 
     46    # This algorithm is presently not used by 
     47    # inh_gamma_generator as it has numerical problems 
    2448    # Try:  
    2549    # plot(stgen.gamma_hazard(arange(0,1000.0,0.1),10.0,1.0/50.0)) 
     
    3963    else: 
    4064        return val 
     65 
     66 
     67def gamma_hazard(x, a, b, dt=1e-4): 
     68    """ 
     69    Compute the hazard function for a gamma process with parameters a,b 
     70    where a and b are the parameters of the gamma PDF: 
     71    y(t) = x^(a-1) \exp(-x/b) / (\Gamma(a)*b^a) 
     72 
     73    Inputs: 
     74        x   - in units of seconds 
     75        a   - dimensionless 
     76        b   - in units of seconds 
     77    """ 
     78 
     79    # Used by inh_gamma_generator 
     80 
     81    # Ideally, I would like to see an implementation which does not depend on RPy 
     82    # but the gamma_hazard_scipy above using scipy exhibits numerical problems, as it does not 
     83    # support directly returning the log. 
     84 
     85    if not check_dependency('rpy'): 
     86        raise ImportError("gamma_hazard requires RPy (http://rpy.sourceforge.net/)") 
     87 
     88    from rpy import r 
     89 
     90    # scipy.special.gammaincc has numerical problems 
     91    #Hpre = -log(scipy.special.gammaincc(a,(x-dt)/b)) 
     92    #Hpost = -log(scipy.special.gammaincc(a,(x+dt)/b)) 
     93 
     94    # reverting to the good old r.pgamma 
     95    Hpre = -r.pgamma(x-dt,shape=a,scale=b,lower=False,log=True) 
     96    Hpost = -r.pgamma(x+dt,shape=a,scale=b,lower=False,log=True) 
     97    val =  0.5*(Hpost-Hpre)/dt 
     98 
     99    return val 
     100 
     101 
    41102     
    42103 
     
    45106    def __init__(self, rng=None, seed=None): 
    46107        """  
    47         Spike Train Generator 
    48         Object to generate spiking random processes with various statistics  
    49         (inhomogeneous poisson, inhomogeneous gamma, etc.) as SpikeTrain objects. 
     108        Stochastic Process Generator 
     109        ============================ 
     110 
     111        Object to generate stochastic processes of various kinds 
     112        and return them as SpikeTrain or AnalogSignal objects. 
     113       
    50114 
    51115        Inputs: 
    52             rng - Seed for the random number generator. Can be None, or  
     116            rng - The random number generator state object (optional). Can be None, or  
    53117                  a numpy.random.RandomState object, or an object with the same  
    54118                  interface. 
     119 
     120            seed - A seed for the rng (optional). 
    55121 
    56122        If rng is not None, the provided rng will be used to generate random numbers,  
     
    60126        Examples: 
    61127            >> x = StGen() 
     128 
     129 
     130 
     131        StGen Methods: 
     132 
     133        Spiking point processes: 
     134        ------------------------ 
     135  
     136        poisson_generator - homogeneous Poisson process 
     137        inh_poisson_generator - inhomogeneous Poisson process (time varying rate) 
     138        inh_gamma_generator - inhomogeneous Gamma process (time varying a,b) 
     139 
     140        Continuous time processes: 
     141        -------------------------- 
     142 
     143        OU_generator - Ohrnstein-Uhlenbeck process 
     144         
     145 
     146        See also: 
     147          shotnoise_fromspikes 
     148 
    62149        """ 
    63150 
     
    75162 
    76163 
    77     def poisson_generator(self, rate, t_start=0.0, t_stop=1000.0, array=False): 
     164    def poisson_generator(self, rate, t_start=0.0, t_stop=1000.0, array=False,debug=False): 
    78165        """ 
    79166        Returns a SpikeList whose spikes are a realization of a Poisson process 
     
    98185        """ 
    99186 
    100         number = int((t_stop-t_start)/1000.0*2.0*rate) 
     187        #number = int((t_stop-t_start)/1000.0*2.0*rate) 
     188         
     189        # less wasteful than double length method above 
     190        n = (t_stop-t_start)/1000.0*rate 
     191        number = numpy.ceil(n+3*numpy.sqrt(n)) 
     192        if number<100: 
     193            number = min(5+numpy.ceil(2*n),100) 
     194         
    101195        if number > 0: 
    102196            isi = self.rng.exponential(1.0/rate, number)*1000.0 
     
    111205        i = numpy.searchsorted(spikes, t_stop) 
    112206 
     207        extra_spikes = [] 
    113208        if i==len(spikes): 
    114             raise RuntimeError("Internal ISI buffer overrun.  Please file a bug report ticket at http://neuralensemble.org/NeuroTools.") 
    115  
    116         if array: 
    117             return numpy.resize(spikes,(i,)) 
    118  
    119         return SpikeTrain(numpy.resize(spikes,(i,)), t_start=t_start,t_stop=t_stop) 
     209            # ISI buf overrun 
     210             
     211            t_last = spikes[-1] + self.rng.exponential(1.0/rate, 1)[0]*1000.0 
     212 
     213            while (t_last<t_stop): 
     214                extra_spikes.append(t_last) 
     215                t_last += self.rng.exponential(1.0/rate, 1)[0]*1000.0 
     216             
     217            spikes = numpy.concatenate((spikes,extra_spikes)) 
     218 
     219            if debug: 
     220                print "ISI buf overrun handled. len(spikes)=%d, len(extra_spikes)=%d" % (len(spikes),len(extra_spikes)) 
     221 
     222 
     223        else: 
     224            spikes = numpy.resize(spikes,(i,)) 
     225 
     226        if not array: 
     227            spikes = SpikeTrain(spikes, t_start=t_start,t_stop=t_stop) 
     228 
     229 
     230        if debug: 
     231            return spikes, extra_spikes 
     232        else: 
     233            return spikes 
    120234 
    121235 
     
    270384    # TODO: provide optimized C/weave implementation if possible 
    271385 
    272     inh_gamma_generator = _inh_gamma_generator_python 
     386 
     387    def _inh_gamma_generator_unimplemented(self, a, b, t, t_stop, array=False): 
     388        raise Exception("inh_gamma_generator is disabled as dependency RPy was not found.") 
     389 
     390    if check_dependency('rpy'): 
     391        inh_gamma_generator = _inh_gamma_generator_python 
     392    else: 
     393        inh_gamma_generator = _inh_gamma_generator_unimplemented 
    273394 
    274395 
     
    432553    OU_generator = _OU_generator_python2 
    433554 
    434     # TODO: inhomogeneous OU generator 
     555    # TODO: optimized inhomogeneous OU generator 
    435556 
    436557 
     
    439560# TODO fix shotnoise stuff below  ... and write tests 
    440561 
    441     def shotnoise_generator(self,rate,tau,q,num,t): 
    442         """  
    443         Generate shotnoise 
    444  
    445         Inputs: 
    446             rate - the rate (in Hz) of the shotnoise 
    447             tau  - the exponential decay of the synapse 
    448             g    - the quantal increase for a spike 
    449  
    450         quantal-increase-"q"-exponential-decay-"tau" synapse 
    451         and a poisson spike train of "rate" and "num" """ 
    452  
    453         g = numpy.zeros(numpy.shape(t),float) 
    454         for i in xrange(num): 
    455  
    456             spikes = poisson_generator(rate,t[-1]) 
    457             dg=exp_conv(spikes,t,q,tau) 
    458             tmp = add(g,dg,g) 
    459  
    460         return g 
     562def shotnoise_fromspikes(self,rate,tau,q,num,t): 
     563    """  
     564    Generate shotnoise 
     565 
     566    Inputs: 
     567        rate - the rate (in Hz) of the shotnoise 
     568        tau  - the exponential decay of the synapse 
     569        g    - the quantal increase for a spike 
     570 
     571    quantal-increase-"q"-exponential-decay-"tau" synapse 
     572    and a poisson spike train of "rate" and "num" """ 
     573 
     574    g = numpy.zeros(numpy.shape(t),float) 
     575    for i in xrange(num): 
     576 
     577        spikes = poisson_generator(rate,t[-1]) 
     578        dg=exp_conv(spikes,t,q,tau) 
     579        tmp = add(g,dg,g) 
     580 
     581    return g 
    461582 
    462583 
     
    466587 
    467588 
    468 def exp_conv(poisson_train,t,q,tau,eps = 1.0e-8): 
     589def shotnoise_fromspikes(spike_train,q,tau,dt,array=False, eps = 1.0e-8): 
    469590    """  
    470     Convolve poisson spike trains with shot decaying exponentials 
    471     t must be equally spaced arrayrange 
    472     poisson spike times must all in the range of t 
    473     otherwise unpredicted results. 
    474     """ 
    475  
    476     dt = t[1]-t[0] 
     591    Convolves the provided spike train with shot decaying exponentials 
     592    yielding so called shot noise if the spike train is Poisson-like.   
     593    Returns an AnalogSignal if array=False, otherwise (shotnoise,t) as numpy arrays.  
     594 
     595   Inputs: 
     596      spike_train - a SpikeTrain object 
     597      q - the shot jump for each spike 
     598      tau - the shot decay time constant in milliseconds 
     599      dt - the resolution of the resulting shotnoise in milliseconds 
     600      array - if True, returns (shotnoise,t) as numpy arrays, otherwise an AnalogSignal. 
     601      eps - a numerical parameter indicating at what value of  
     602      the shot kernal the tail is cut.  The default is usually fine. 
     603 
     604   Examples: 
     605      >> stg = stgen.StGen() 
     606      >> st = stg.poisson_generator(10.0,0.0,1000.0) 
     607      >> g_e = shotnoise_fromspikes(st,2.0,10.0) 
     608 
     609 
     610   See also: 
     611      poisson_generator, inh_gamma_generator, OU_generator ... 
     612   """ 
     613 
     614    st = spike_train 
     615 
     616    t = numpy.arange(st.t_start,st.t_stop,dt) 
    477617 
    478618    # time of vanishing significance 
    479     vs_t = -tau*log(eps/q) 
    480  
    481     kern = q*exp(-arrayrange(0.0,vs_t,dt)/tau) 
    482  
    483     idx = clip(searchsorted(t,poisson_train),0,len(t)-1) 
    484  
    485     a = zeros(shape(t),Float) 
    486  
    487     put(a,idx,1.0) 
    488  
    489     return convolve(a,kern)[0:len(t)] 
     619    vs_t = -tau*numpy.log(eps/q) 
     620 
     621    kern = q*numpy.exp(-numpy.arange(0.0,vs_t,dt)/tau) 
     622 
     623    idx = numpy.clip(numpy.searchsorted(t,poisson_train,'right')-1,0,len(t)-1) 
     624 
     625    a = numpy.zeros(shape(t),float) 
     626 
     627    a[idx] = 1.0 
     628 
     629    y = convolve(a,kern)[0:len(t)] 
     630 
     631    result = AnalogSignal(y,dt,t_start=0,t_stop=st.t_stop-st.t_start) 
     632    result.time_offset(st.t_start) 
     633    return result 
     634 
    490635 
    491636 
  • trunk/test/test_stgen.py

    r299 r307  
    4141        assert isinstance(st, numpy.ndarray) 
    4242 
     43        st = stg.poisson_generator(rate,0.0,t_stop,array=True,debug=True) 
     44         
     45        assert isinstance(st[0], numpy.ndarray) 
     46        assert isinstance(st[1], list) 
     47 
     48        st = stg.poisson_generator(rate,0.0,t_stop,debug=True) 
     49         
     50        assert isinstance(st[0], signals.SpikeTrain) 
     51        assert isinstance(st[1], list) 
     52 
    4353 
    4454    def testStatsPoisson(self): 
     55 
     56        # this is a statistical test with non-zero chance of failure 
     57 
     58        def test_poisson(rate,t_start,t_stop): 
     59            stg = stgen.StGen() 
     60            dt = t_stop-t_start 
     61            N = rate*dt/1000.0 
     62 
     63            st = stg.poisson_generator(rate,t_start=t_start,t_stop=t_stop,array=True) 
     64 
     65            if len(st) in (0,1,2,3): 
     66                assert N<15 
     67                return 
     68 
     69 
     70 
     71            assert st[-1] < t_stop 
     72            assert st[0] > t_start 
     73 
     74 
     75            # last spike should not be more than 4 ISI away from t_stop 
     76            err = """ 
     77    Last spike should not be more than 4 ISI behind t_stop. 
     78    There is a non-zero chance for this to occur during normal operation. 
     79    Re-run the test to see if the error persists.""" 
     80 
     81            if st[-1] < t_stop-4.0*1.0/rate*1000.0: 
     82                raise StatisticalError(err) 
     83 
     84 
     85            # first spike should not be more than 4 ISI away from t_start 
     86            err = """ 
     87    First spike should not be more than 4 ISI in front of t_start. 
     88    There is a non-zero chance for this to occur during normal operation. 
     89    Re-run the test to see if the error persists.""" 
     90 
     91            if st[0] > t_start+4.0*1.0/rate*1000.0: 
     92                raise StatisticalError(err) 
     93 
     94            err = """ 
     95    Number of spikes should be within 3 standard deviations of mean. 
     96    There is a non-zero chance for this to occur during normal operation. 
     97    Re-run the test to see if the error persists.""" 
     98 
     99 
     100            if len(st) > N+3.0*numpy.sqrt(N) or len(st) < N-3.0*numpy.sqrt(N): 
     101                raise StatisticalError(err) 
     102 
     103 
     104        # high rates 
     105 
     106        test_poisson(100.0,500.0,1500.0) 
     107 
     108        # high rates, short time 
     109 
     110        test_poisson(100.0,500.0,550.0) 
     111 
     112        # low rates, short time 
     113 
     114        test_poisson(2.0,500.0,550.0) 
     115 
     116        # low rates, long time 
     117 
     118        test_poisson(5.0,500.0,50500.0) 
     119 
     120 
     121 
     122 
     123    def testStatsInhPoisson(self): 
    45124 
    46125        # this is a statistical test with non-zero chance of failure 
     
    52131        t_stop = 1500.0 # milliseconds 
    53132 
    54         st = stg.poisson_generator(rate,t_start=t_start,t_stop=t_stop,array=True) 
    55  
    56         assert st[-1] < t_stop 
    57         assert st[0] > t_start 
    58  
    59  
    60         # last spike should not be more than 4 ISI away from t_stop 
    61         err = """ 
    62 Last spike should not be more than 4 ISI behind t_stop. 
    63 There is a non-zero chance for this to occur during normal operation. 
    64 Re-run the test to see if the error persists.""" 
    65  
    66         if st[-1] < t_stop-4.0*1.0/rate*1000.0: 
    67             raise StatisticalError(err) 
    68  
    69  
    70         # first spike should not be more than 4 ISI away from t_start 
    71         err = """ 
    72 First spike should not be more than 4 ISI in front of t_start. 
    73 There is a non-zero chance for this to occur during normal operation. 
    74 Re-run the test to see if the error persists.""" 
    75  
    76         if st[0] > t_start+4.0*1.0/rate*1000.0: 
    77             raise StatisticalError(err) 
    78  
    79         err = """ 
    80 Number of spikes should be within 3 standard deviations of mean. 
    81 There is a non-zero chance for this to occur during normal operation. 
    82 Re-run the test to see if the error persists.""" 
    83  
    84         # time interval is one second 
    85  
    86         if len(st) > rate+3.0*numpy.sqrt(rate) or len(st) < rate-3.0*numpy.sqrt(rate): 
    87             raise StatisticalError(err) 
    88  
    89  
    90     def testStatsInhPoisson(self): 
    91  
    92         # this is a statistical test with non-zero chance of failure 
    93  
    94         stg = stgen.StGen() 
    95  
    96         rate = 100.0 #Hz 
    97         t_start = 500.0 
    98         t_stop = 1500.0 # milliseconds 
    99  
    100133        st = stg.inh_poisson_generator(numpy.array([rate]),numpy.array([t_start]),t_stop,array=True) 
    101134