root/trunk/src/random.py

Revision 353, 4.5 kB (checked in by apdavison, 2 months ago)

Minor improvements to parameters and random modules

  • Property svn:eol-style set to native
Line 
"""
NeuroTools.random
=================

A set of classes representing statistical distributions, with an interface that
is compatible with the ParameterSpace class in the parameters module.

Classes
-------

ParameterDist - abstract base class; subclasses implement next(n)
GammaDist     - gamma distribution, gamma.pdf(x,a,b) = x**(a-1)*exp(-x/b)/gamma(a)/b**a
NormalDist    - normal (Gaussian) distribution
UniformDist   - uniform distribution

"""
16
17 from NeuroTools import check_dependency
18
19 import numpy, numpy.random
20
21 have_scipy = check_dependency('scipy')
22 if have_scipy:
23     import scipy.stats
24
25    
class ParameterDist(object):
    """
    Abstract base class for the distributions in this module.

    A distribution is described by the keyword parameters given to the
    constructor, which are stored in the ``params`` dict.  Subclasses set
    ``dist_name`` and implement ``next(n)`` to draw ``n`` samples.
    """

    def __init__(self, **params):
        self.params = params
        self.dist_name = 'ParameterDist'

    def __repr__(self):
        """Return e.g. 'ParameterDist(a=1,b=2)', built from self.params."""
        args = ','.join('%s=%s' % (key, str(self.params[key]))
                        for key in self.params)
        return '%s(%s)' % (self.dist_name, args)

    def next(self, n=1):
        """Draw n samples; must be overridden by subclasses."""
        raise NotImplementedError('This is an abstract base class and cannot be used directly')

    def from_stats(self, vals, bias=0.0, expand=1.0):
        """
        Re-initialise this distribution, in place, from sample statistics.

        The new mean is numpy.mean(vals) + bias and the new std is
        numpy.std(vals) * expand; both are passed to self.__init__ as the
        keyword arguments 'mean' and 'std'.  Returns None.
        """
        self.__init__(mean=numpy.mean(vals) + bias,
                      std=numpy.std(vals) * expand)

    def __eq__(self, o):
        """Equal when type, dist_name and params all match."""
        # should we track the state of the rng and return False if it is
        # different between self and o?
        return bool(type(self) == type(o) and
                    self.dist_name == o.dist_name and
                    self.params == o.params)

    def __ne__(self, o):
        # Python 2 does not derive != from __eq__; without this method two
        # distributions that compare equal would also compare unequal.
        return not self.__eq__(o)
54
class GammaDist(ParameterDist):
    """
    Gamma distribution with shape a and scale b:

        gamma.pdf(x,a,b) = x**(a-1)*exp(-x/b)/gamma(a)/b**a

    For a, b > 0 it yields strictly positive numbers, with
    mean = a*b and std = b*sqrt(a).

    Samples are drawn as numpy.random.standard_gamma(a)*b.  This is the
    computation scipy.stats.gamma.rvs(a)*b performs internally, so scipy
    is not required.  See also scipy.stats.gamma.
    """

    def __init__(self, mean=None, std=None, repr_mode='ms', **params):
        """
        mean, std -- mean and standard deviation of the distribution.  If
                     either is given, a and b are derived from them.  A
                     missing one defaults to 1.0, which matches the default
                     a=b=1 distribution.  Previously a missing mean
                     defaulted to 0.0, which always raised
                     ZeroDivisionError.
        repr_mode -- controls how repr() displays the distribution: 'ms'
                     (the default) shows the mean and std, any other value
                     shows a and b.
        params    -- may contain 'm'/'s', aliases used for mean/std when
                     those are not passed.  It may also contain 'a'/'b'
                     (shape/scale, each defaulting to 1.0), which are used
                     only when neither mean nor std is given.
        """
        self.repr_mode = repr_mode
        if mean is None:
            mean = params.get('m')
        if std is None:
            std = params.get('s')

        if mean is None and std is None:
            a = params.get('a', 1.0)
            b = params.get('b', 1.0)
        else:
            if mean is None:
                mean = 1.0
            if std is None:
                std = 1.0
            # float() avoids Python 2 integer division for int arguments
            # (e.g. mean=1, std=2 would otherwise give a == 0).
            a = float(mean)**2 / float(std)**2
            b = mean / a
        ParameterDist.__init__(self, a=a, b=b)
        self.dist_name = 'GammaDist'

    def next(self, n=1):
        """Return an array of n samples drawn from the global numpy RNG."""
        return numpy.random.standard_gamma(self.params['a'], size=n) * self.params['b']

    def from_stats(self, vals, bias=0.0, expand=1.0):
        """
        Re-initialise from the mean/std of vals (see ParameterDist).

        The current repr_mode is kept.  Otherwise __init__ would reset it
        to 'ms'.
        """
        repr_mode = self.repr_mode
        ParameterDist.from_stats(self, vals, bias=bias, expand=expand)
        self.repr_mode = repr_mode

    def mean(self):
        """Mean of the distribution, a*b."""
        return self.params['a'] * self.params['b']

    def std(self):
        """Standard deviation of the distribution, b*sqrt(a)."""
        return self.params['b'] * numpy.sqrt(self.params['a'])

    def __repr__(self):
        if self.repr_mode == 'ms':
            return '%s(m=%f,s=%f)' % (self.dist_name, self.mean(), self.std())
        else:
            return '%s(a=%f,b=%f)' % (self.dist_name, self.params['a'], self.params['b'])
115        
116
class NormalDist(ParameterDist):
    """
    Normal (Gaussian) distribution with parameters mean and std.

    Samples are drawn with numpy.random.normal(loc=mean, scale=std).
    """

    def __init__(self, mean=0.0, std=1.0):
        ParameterDist.__init__(self, mean=mean, std=std)
        self.dist_name = 'NormalDist'

    def next(self, n=1):
        """Return an array of n samples drawn from the global numpy RNG."""
        return numpy.random.normal(loc=self.params['mean'],
                                   scale=self.params['std'],
                                   size=n)

    def mean(self):
        """Mean of the distribution (same accessor as GammaDist.mean)."""
        return self.params['mean']

    def std(self):
        """Standard deviation of the distribution (same accessor as GammaDist.std)."""
        return self.params['std']
136        
137
class UniformDist(ParameterDist):
    """
    Uniform distribution on the half-open interval [min, max).

    return_type -- type of the values returned by next().  If it is not
                   float, the float samples are converted with
                   ndarray.astype(return_type).
    """

    def __init__(self, min=0.0, max=1.0, return_type=float):
        ParameterDist.__init__(self, min=min, max=max)
        self.dist_name = 'UniformDist'
        self.return_type = return_type

    def next(self, n=1):
        """Return an array of n samples drawn from the global numpy RNG."""
        vals = numpy.random.uniform(low=self.params['min'],
                                    high=self.params['max'],
                                    size=n)
        if self.return_type != float:
            vals = vals.astype(self.return_type)
        return vals

    def from_stats(self, vals, bias=0.0, expand=1.0):
        """
        Re-initialise in place so the range covers the range of vals.

        The new interval is centred on the midpoint of min(vals) and
        max(vals), shifted by bias.  Its half-width is half the observed
        range times expand.  The current return_type is kept.
        """
        mn = numpy.min(vals)
        mx = numpy.max(vals)
        center = 0.5 * (mx + mn) + bias
        hw = 0.5 * (mx - mn) * expand
        # Pass return_type through: __init__ would otherwise reset it to float.
        self.__init__(min=center - hw, max=center + hw,
                      return_type=self.return_type)
Note: See TracBrowser for help on using the browser.