root/trunk/src/parameters.py

Revision 355, 25.2 kB (checked in by apdavison, 2 months ago)

Can now (again) combine ParameterRanges and ParameterDists inside a ParameterSpace, e.g.::

for sub_parameter_space in parameter_space.iter_inner(copy=True):
    for parameter_set in sub_parameter_space.realize_dists(n=2, copy=True):
        print parameter_set.pretty()

Before, this would raise an Exception because iter_inner() would return a ParameterSet even if there were ParameterDists inside it.

Presumably, this used to work before I split the ParameterSet into separate ParameterSet and ParameterSpace classes, so really I'm just fixing a bug I introduced then.

Line 
1 """
2 NeuroTools.parameters
3 =====================
4
5 A module for dealing with model parameters.
6
7 Classes
8 -------
9
10 Parameter
11 ParameterRange - for specifying a list of possible values for a given parameter.
12 ParameterSet   - for representing/managing hierarchical parameter sets.
13 ParameterTable - a sub-class of ParameterSet that can represent a table of parameters.
14 ParameterSpace - a collection of ParameterSets, representing multiple points in
15                  parameter space.
16
17 Functions
18 ---------
19
20 nesteddictwalk    - Walk a nested dict structure, using a generator.
21 nesteddictflatten - Return a flattened version of a nested dict structure.
22 string_table      - Convert a table written as a multi-line string into a dict of dicts.
23
24 """
25
26 import urllib, copy, warnings, numpy, numpy.random  # to be replaced with srblib
27 from urlparse import urlparse
28 from NeuroTools import check_dependency
29 from NeuroTools.random import ParameterDist, GammaDist, UniformDist, NormalDist
30
31
def isiterable(x):
    """
    Return True if `x` has an `__iter__` attribute and is not a string
    (`basestring` instance); return False otherwise.
    """
    if not hasattr(x, '__iter__'):
        return False
    return not isinstance(x, basestring)
34
def contains_instance(collection, cls):
    """
    Return True as soon as one element of `collection` is an instance of
    `cls`; return False if none is (including for an empty collection).
    """
    for element in collection:
        if isinstance(element, cls):
            return True
    return False
37
38 # Deprecated: should use ParameterSet('file:///path/to/filename')
39 #def read_parameters(filename):
40 #    """Read parameters from a text file."""
41 #    parameters = {}
42 #    f = open(filename, 'r')
43 #    exec(f) in globals(), parameters
44 #    f.close()
45 #    return parameters
46
def nesteddictwalk(d, separator='.'):
    """
    Walk a nested dict structure, using a generator.

    Yields `(key, value)` pairs for every value that is not itself a dict.
    For values found inside sub-dicts, the key is a composite created by
    joining each key to the key of the parent dict using `separator`.
    """
    for key, value in d.items():
        if not isinstance(value, dict):
            # leaf: report it under its own key
            yield key, value
            continue
        # sub-dict: recurse and prefix every key coming back with ours
        for subkey, subvalue in nesteddictwalk(value, separator):
            yield "%s%s%s" % (key, separator, subkey), subvalue
60            
def nesteddictflatten(d, separator='.'):
    """
    Return a flattened version of a nested dict structure.

    The result is a single-level dict built from the pairs produced by
    `nesteddictwalk`: composite keys are created by joining each key to the
    key of the parent dict using `separator`.
    """
    return dict(nesteddictwalk(d, separator))
72
73
74 # --- Parameters, and ranges and distributions of them -------------------------
75
76
class Parameter(object):
    """
    A single parameter value with an optional name and optional units.

    The Python type of the value at construction time is stored in `type`.
    """

    def __init__(self, value, units=None, name=""):
        self.value = value
        self.type = type(value)
        self.units = units
        self.name = name

    def __repr__(self):
        """Return "name = value", followed by the units when they are set."""
        text = "%s = %s" % (self.name, self.value)
        if self.units is None:
            return text
        return text + " %s" % self.units
90
91
class ParameterRange(Parameter):
    """
    A class for specifying a list of possible values for a given parameter.

    The value must be an iterable. It acts like a Parameter, but .next() can be
    called to iterate through the values.

    If `shuffle` is true, the values are stored as a random permutation
    (a numpy array produced by `numpy.random.permutation`).
    """

    def __init__(self, value, units=None, name="", shuffle=False):
        if not isiterable(value):
            raise TypeError("A ParameterRange value must be iterable")
        # The Parameter `value`/`type` are taken from the first element of the
        # iterable as supplied (before any shuffling).
        Parameter.__init__(self, value.__iter__().next(), units, name)
        if shuffle:
            self._values = numpy.random.permutation(value)
        else:
            self._values = value
        # Start the internal iterator on the values actually stored, so that
        # next() and iteration agree when the range has been shuffled.
        self._iter_values = self._values.__iter__()

    def __repr__(self):
        units_str = ''
        if self.units:
            units_str = ', units="%s"' % self.units
        return 'ParameterRange(%s%s)' % (self._values.__repr__(), units_str)

    def __iter__(self):
        """Restart the internal iterator and return it."""
        self._iter_values = self._values.__iter__()
        return self._iter_values

    def next(self):
        """Advance the internal iterator and return the new current value."""
        # NOTE(review): the current value is stored in `_value`, not in the
        # `value` attribute inherited from Parameter - confirm this is intended.
        self._value = self._iter_values.next()
        return self._value

    def __len__(self):
        return len(self._values)

    def __eq__(self, o):
        """
        Two ranges are equal if they have the same type, name, units and
        values.
        """
        if type(self) != type(o):
            return False
        if self.name != o.name or self.units != o.units:
            return False
        values_equal = (self._values == o._values)
        if isinstance(values_equal, numpy.ndarray):
            # Shuffled ranges hold numpy arrays, which compare element-wise;
            # reduce to a single truth value and guard against broadcasting
            # hiding a length mismatch.
            return (len(self._values) == len(o._values)
                    and bool(values_equal.all()))
        return bool(values_equal)

    def __ne__(self, o):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        return not self.__eq__(o)
135        
136
137 # --- ParameterSet and subclasses ----------------------------------------------
138
class ParameterSet(dict):
    """
    A class to manage hierarchical parameter sets.

    Usage example:
    >>> sim_params = ParameterSet({'dt': 0.1, 'tstop': 1000.0})
    >>> exc_cell_params = ParameterSet("http://neuralensemble.org/svn/NeuroTools/example.params")
    >>> inh_cell_params = ParameterSet({'tau_m': 15.0, 'cm': 0.5})
    >>> network_params = ParameterSet({'excitatory_cells': exc_cell_params, 'inhibitory_cells': inh_cell_params})
    >>> P = ParameterSet({'sim': sim_params, 'network': network_params})
    >>> P.sim.dt
    0.1
    >>> P.network.inhibitory_cells.tau_m
    15.0
    >>> print P.pretty()

    """

    # Names stored as real instance attributes rather than as dict entries
    # (see __setattr__).
    non_parameter_attributes = ['_url','label','names','parameters','flat','flatten','non_parameter_attributes']
    # Names rejected as parameter names by check_validity().
    invalid_names = ['parameters', 'names'] # should probably add dir(dict)

    @staticmethod
    def read_from_str(str):
        """
        ParameterSet definition str should be a Python dict definition
        string, containing objects of types int, float, str, list,
        dict plus the classes defined in this module, `Parameter`,
        `ParameterRange`, etc.  No other object types are allowed,
        except the function url('some_url'), e.g.: { 'a' : {'A': 3,
        'B': 4}, 'b' : [1,2,3], 'c' : 'hello world', 'd' :
        url('http://example.com/my_cool_parameter_set') }

        This is largely the JSON (www.json.org) format, but with
        extra keywords in the Namespace such as ParameterRange, GammaDist, etc.

        Python also supports specifying dictionaries as follows:

        dict(x=1,y=2)

        But usage of such un-JSON-ly idioms should likely be discouraged...

        Returns the evaluated object, or an empty dict if it is falsy.
        Raises SyntaxError if the string cannot be parsed.
        """
        global_dict = dict(url=ParameterSet, ParameterSet=ParameterSet)
        global_dict.update(dict(ParameterRange=ParameterRange,
                                ParameterTable=ParameterTable,
                                GammaDist=GammaDist,
                                UniformDist=UniformDist,
                                NormalDist=NormalDist,
                                pi=numpy.pi))
        try:
            # SECURITY: eval() executes arbitrary Python expressions. Only use
            # this with parameter strings/files/URLs from a trusted source.
            D = eval(str, global_dict)
        except SyntaxError:
            raise SyntaxError("Invalid string for ParameterSet definition: %s" % str)
        return D or {}

    @staticmethod
    def check_validity(k):
        """Raise an Exception if `k` is not allowed as a parameter name."""
        if k in ParameterSet.invalid_names:
            raise Exception("'%s' is not allowed as a parameter name." % k)

    def __init__(self, initialiser, label=None):
        """
        Create a ParameterSet from `initialiser`, which may be a dict, a
        ParameterSet, a URL, or a string containing a dict definition (see
        `read_from_str`). Nested dicts are converted to ParameterSet objects
        labelled with their key; the dicts passed in are not modified.

        Raises TypeError if `initialiser` does not end up being a dict.
        """
        self._url = None
        if isinstance(initialiser, basestring): # url or str
            try:
                # can't handle cases where authentication is required
                # should be rewritten using urllib2
                f = urllib.urlopen(initialiser)
                try:
                    pstr = f.read()
                finally:
                    # close the handle even if reading fails
                    f.close()
                self._url = initialiser
            except IOError:
                # not something we can open: treat the string itself as the
                # parameter set definition
                pstr = initialiser
                self._url = None

            initialiser = ParameterSet.read_from_str(pstr)

        # By this stage, `initialiser` should be a dict. Iterate through it,
        # copying its contents into the current instance, and replacing dicts by
        # ParameterSet objects.
        if not isinstance(initialiser, dict):
            raise TypeError("`initialiser` must be a `dict`, a `ParameterSet` object or a valid URL")
        for k, v in initialiser.items():
            ParameterSet.check_validity(k)
            if isinstance(v, dict) and not isinstance(v, ParameterSet):
                # The constructor recurses into deeper levels itself and
                # builds new objects, so the caller's dicts stay untouched.
                v = ParameterSet(v, k)
            self[k] = v

        # Set the label
        if hasattr(initialiser, 'label'):
            self.label = label or initialiser.label # if initialiser was a ParameterSet, keep the existing label if the label arg is None
        else:
            self.label = label

        # Define some aliases, allowing, e.g.:
        # for name, value in P.parameters():
        # for name in P.names():
        self.names = self.keys
        self.parameters = self.items

    def flat(self):
        """
        Walk the nested parameter set, using a generator.

        Yields (key, value) pairs for all non-dict values; keys of nested
        values are joined to the keys of their parents with '.'.
        """
        return nesteddictwalk(self)

    def flatten(self):
        """
        Return a flattened (single-level) dict version of the parameter set.

        Keys of nested values are joined to the keys of their parents with '.'.
        """
        return nesteddictflatten(self)

    def __getattr__(self, name):
        """Allow accessing parameters using dot notation."""
        try:
            return self[name]
        except KeyError:
            return self.__getattribute__(name)

    def __setattr__(self, name, value):
        """Allow setting parameters using dot notation."""
        if name in self.non_parameter_attributes:
            object.__setattr__(self, name, value)
        else:
            # should we check the parameter type hasn't changed?
            self[name] = value

    def __getitem__(self,name):
        """ Modified get that detects dots '.' in the names and goes down the
        nested tree to find it"""
        split = name.split('.',1)
        if len(split)==1:
            return dict.__getitem__(self,name)
        # nested get
        return dict.__getitem__(self,split[0])[split[1]]

    def __setitem__(self,name,value):
        """ Modified set that detects dots '.' in the names and goes down the
        nested tree to set it """

        split = name.split('.',1)
        if len(split)==1:
            dict.__setitem__(self,name,value)
        else:
            # nested set
            dict.__getitem__(self,split[0])[split[1]]=value

    # should __len__() be the usual dict length, or the flattened length? Probably the former for consistency with dicts
    # can always use len(ps.flatten())

    # what about __contains__()? Should we drill down to lower levels in the hierarchy? I think so.

    def __getstate__(self):
        """For pickling."""
        return self

    def save(self, url=None, expand_urls=False):
        """
        Write the parameter set to a text file.

        The text file syntax is open to discussion. My idea is that it should be
        valid Python code, preferably importable as a module.

        If `url` is `None`, try to save to `self._url` (if it is not `None`),
        otherwise save to `url`.

        Only local files ('file' scheme, or no scheme and no network
        location) are supported; anything else raises an Exception.
        """
        # possible solution for HTTP PUT: http://inamidst.com/proj/put/put.py
        if not url:
            url = self._url
        assert url != ''
        if not self._url:
            self._url = url
        scheme, netloc, path, parameters, query, fragment = urlparse(url)
        if scheme == 'file' or (scheme=='' and netloc==''):
            # Render first, so that a failure in pretty() does not leave a
            # truncated file behind.
            text = self.pretty(expand_urls=expand_urls)
            f = open(path, 'w')
            try:
                f.write(text)
            finally:
                f.close()
        else:
            if scheme:
                raise Exception("Saving using the %s protocol is not implemented" % scheme)
            else:
                raise Exception("No protocol (http, ftp, etc) specified.")

    def pretty(self, indent='  ', expand_urls=False):
        """
        Return a unicode string representing the structure of the `ParameterSet`.
        `eval`uating the string should recreate the object.

        `indent` is the indentation added per nesting level. Unless
        `expand_urls` is true, nested sets that have a `_url` are written as
        `url("...")` references instead of being expanded.
        """
        def walk(d, indent, ind_incr):
            s = []
            for k,v in d.items():
                if hasattr(v, 'items'):
                    if expand_urls is False and hasattr(v, '_url') and v._url:
                        s.append('%s"%s": url("%s"),' % (indent, k, v._url))
                    else:
                        s.append('%s"%s": {' % (indent, k))
                        s.append(walk(v, indent+ind_incr,  ind_incr))
                        s.append('%s},' % indent)
                elif isinstance(v, basestring):
                    s.append('%s"%s": "%s",' % (indent, k, v))
                else: # what if we have a dict or ParameterSet inside a list? currently they are not expanded. Should they be?
                    s.append('%s"%s": %s,' % (indent, k, v))
            return '\n'.join(s)
        return '{\n' + walk(self, indent, indent) + '\n}'

    def tree_copy(self):
        """ returns a copy of the ParameterSet tree structure.
        Nodes are not copied, but re-referenced.

        The copy is a ParameterSpace if it contains any ParameterRange or
        ParameterDist, otherwise a ParameterSet."""

        tmp = ParameterSet({})
        for key in self:
            value = self[key]
            if isinstance(value, ParameterSet):
                # recurse
                tmp[key]=value.tree_copy()
            else:
                tmp[key]=value
        if tmp._is_space():
            tmp = ParameterSpace(tmp)
        return tmp

    def as_dict(self):
        """ returns a copy of the ParameterSet tree structure
        as a nested dictionary"""

        tmp = {}

        for key in self:
            value = self[key]
            if isinstance(value, ParameterSet):
                # recurse
                tmp[key]=value.as_dict()
            else:
                tmp[key]=value
        return tmp

    def __sub__(self, other):
        """
        Return the difference between this ParameterSet and another.
        Not yet properly implemented.

        Returns a 2-tuple of dicts: the items only in (or different in)
        `self`, and the items only in (or different in) `other`.
        """
        self_keys = set(self)
        other_keys = set(other)
        intersection = self_keys.intersection(other_keys)
        difference1 = self_keys.difference(other_keys)
        difference2 = other_keys.difference(self_keys)
        result1 = dict([(key, self[key]) for key in difference1])
        result2 = dict([(key, other[key]) for key in difference2])
        # Now need to check values for intersection....
        for item in intersection:
            if isinstance(self[item], ParameterSet):
                d1,d2 = self[item] - other[item]
                if d1:
                    result1[item] = d1
                if d2:
                    result2[item] = d2
            elif self[item] != other[item]:
                result1[item] = self[item]
                result2[item] = other[item]
        if len(result1) + len(result2) == 0:
            assert self == other, "Error in ParameterSet.diff()"
        return result1, result2

    def _is_space(self):
        """
        Checks for the presence of ParameterRanges or ParameterDists to
        determine if this is a ParameterSet or a ParameterSpace.
        """
        for k,v in self.flat():
            if isinstance(v, (ParameterRange, ParameterDist)):
                return True
        return False
427    
428
class ParameterSpace(ParameterSet):
    """A collection of ParameterSets, representing multiple points in
    parameter space. Created by putting ParameterRange and/or ParameterDist
    objects within a ParameterSet."""

    def iter_range_key(self,range_key):
        """ An iterator of the ParameterSpace which yields the
        ParameterSet with the ParameterRange given by key replaced with
        each of its values

        Note: the same object is yielded (and modified) each time."""

        tmp = self.tree_copy()
        for val in self[range_key]:
            tmp[range_key] = val
            yield tmp

    def iter_inner_range_keys(self,keys,copy=False):
        """ An iterator of the ParameterSpace which yields
        ParameterSets with all combinations of ParameterRange elements
        which are given by the keys list

        Note: each newly yielded value is one and the same object
        so storing the returned values results in a collection
        of many of the lastly yielded object.

        copy=True causes each yielded object to be a newly
        created object, but be careful because this is
        spawning many dictionaries!

        A yielded object that no longer contains any ParameterRange or
        ParameterDist is a plain ParameterSet; otherwise it is still a
        ParameterSpace.
        """
        if len(keys)==0:
            # return an iterator over 1 copy for modifying
            yield self.tree_copy()
            return

        first_key = keys[0]
        # recursively iterate over remaining keys
        for tmp in self.iter_inner_range_keys(keys[1:]):
            # iterator over range of our present attention
            for val in self[first_key]:
                if copy:
                    # each yielded object is a fresh tree_copy
                    result = tmp.tree_copy()
                else:
                    result = tmp
                result[first_key] = val
                if not result._is_space():
                    # Fully resolved: hand back a ParameterSet. This applies
                    # to the object being yielded in both modes.
                    result = ParameterSet(result)
                if not copy:
                    # keep re-using the (possibly converted) object
                    tmp = result
                yield result

    def range_keys(self):
        """ returns the list of keys for those elements which are ParameterRanges """
        return [key for key,value in self.flat() if isinstance(value,ParameterRange)]

    def iter_inner(self,copy=False):
        """An iterator of the ParameterSpace which yields
        ParameterSets with all combinations of ParameterRange elements"""

        return self.iter_inner_range_keys(self.range_keys(),copy)

    def num_conditions(self):
        """Returns the number of ParameterSets that will be returned by the
        iter_inner() method."""
        # Not properly tested
        n = 1
        for key in self.range_keys():
            n *= len(self[key])
        return n

    def dist_keys(self):
        """ returns the list of keys for those elements which are ParameterDists,
        or iterables containing at least one ParameterDist """
        def is_or_contains_dist(value):
            return isinstance(value, ParameterDist) or (
                isiterable(value) and contains_instance(value, ParameterDist))
        return [key for key,value in self.flat() if is_or_contains_dist(value)]

    def realize_dists(self,n=1,copy=False):
        """For each ParameterDist, realize the distribution and yield the result

        `n` objects are yielded, each with every ParameterDist (including
        those inside iterables) replaced by one of `n` values drawn from it.

        If copy==True, causes each yielded object to be a newly
        created object, but be careful because this is
        spawning many dictionaries!"""
        def draw(item, n):
            # n draws from a distribution, or n repeats of a fixed item
            if isinstance(item, ParameterDist):
                return item.next(n)
            else:
                return [item]*n
        # pre-generate random numbers
        rngs = {}
        for key in self.dist_keys():
            if isiterable(self[key]):
                rngs[key] = [draw(item, n) for item in self[key]]
            else:
                rngs[key] = self[key].next(n)
        # get a copy to fill in the rngs
        tmp = self.tree_copy()
        for i in range(n):
            for key in rngs:
                if isiterable(self[key]):
                    # i-th draw of every element of the iterable
                    tmp[key] = [samples[i] for samples in rngs[key]]
                else:
                    tmp[key] = rngs[key][i]
            if copy:
                yield tmp.tree_copy()
            else:
                yield tmp

    def parameter_space_dimension_labels(self):
        """
        returns the dimensions and labels of the keys for those elements which are ParameterRanges
        range_keys are sorted to ensure same ordering each time.
        """
        range_keys = self.range_keys()
        range_keys.sort()

        # Item access already resolves dotted (nested) keys, so no eval() is
        # needed to reach nested ranges.
        dim = [len(self[key]) for key in range_keys]
        label = list(range_keys)

        return dim,label

    def parameter_space_index(self,current_experiment):
        """
        returns the index of the current experiment in the dimension of the parameter space
        i.e. parameter space dimension: [2,3]
        i.e. index: (1,0)

        Raises ValueError if a value of `current_experiment` is not among
        the values of the corresponding ParameterRange.

        Example:

        p = ParameterSpace({})
        p.b = ParameterRange([1,2,3])
        p.a = ParameterRange(['p','y','t','h','o','n'])

        results_dim, results_label = p.parameter_space_dimension_labels()

        results = numpy.empty(results_dim)
        for experiment in p.iter_inner():
            index = p.parameter_space_index(experiment)
            results[index] = 2.

        """
        index = []
        range_keys = self.range_keys()
        range_keys.sort()
        for key in range_keys:
            # dotted keys are resolved by ParameterSet.__getitem__
            value = current_experiment[key]
            try:
                value_index = list(self[key]._values).index(value)
            except ValueError:
                raise ValueError("The ParameterSet provided is not within the ParameterSpace")
            index.append(value_index)
        return tuple(index)
598        
599
def string_table(tablestring):
    """
    Convert a table written as a multi-line string into a dict of dicts.

    The first non-blank line holds the column headers; its first field is
    ignored (it sits above the row labels). On each following line the first
    field is the row label and the remaining fields are converted to float.
    Blank lines are skipped.

    Returns a dict mapping row label -> {column header: float value}.
    Raises ValueError if an item cannot be converted to float.
    """
    tabledict = {}
    rows = tablestring.strip().split('\n')
    column_headers = rows[0].split()
    for row in rows[1:]:
        fields = row.split()
        if not fields:
            # blank (or whitespace-only) line inside the table
            continue
        row_header = fields[0]
        tabledict[row_header] = {}
        for col_header, item in zip(column_headers[1:], fields[1:]):
            tabledict[row_header][col_header] = float(item)
    return tabledict
612
613
class ParameterTable(ParameterSet):
    """
    A sub-class of ParameterSet that can represent a table of parameters.

    i.e., it is limited to one-level of nesting, and each sub-dict must have
    the same keys. In addition to the possible initialisers for ParameterSet,
    a ParameterTable can be initialised from a multi-line string, e.g.

        >>> pt = ParameterTable('''
        ...     #       col1    col2    col3
        ...     row1     1       2       3
        ...     row2     4       5       6
        ...     row3     7       8       9
        ... ''')
        >>> pt.row2.col3
        6.0
        >>> pt.column('col1')
        {'row1': 1.0, 'row2': 4.0, 'row3': 7.0}
        >>> pt.transpose().col3.row2
        6.0

    """

    non_parameter_attributes = ParameterSet.non_parameter_attributes + \
                               ['row', 'rows', 'row_labels',
                                'column', 'columns', 'column_labels']

    def __init__(self, initialiser, label=None):
        if isinstance(initialiser, basestring): # url or table string
            tabledict = string_table(initialiser)
            # if initialiser is a URL, string_table() should return an empty dict
            # since URLs do not contain spaces.
            if tabledict: # string table
                initialiser = tabledict
        ParameterSet.__init__(self, initialiser, label)
        # Now need to check that the contents actually define a table, i.e.
        # two levels of nesting and each sub-dict has the same keys
        self._check_is_table()

        # Aliases: rows() returns (row_label, row) pairs, row_labels() the
        # row labels.
        self.rows = self.items
        #self.rows.__doc__ = "Return a list of (row_label, row) pairs, as 2-tuples."""
        self.row_labels = self.keys
        #self.row_labels.__doc__ = "Return a list of row labels."

    def _check_is_table(self):
        """
        Checks that the contents actually define a table, i.e.
        one level of nesting and each sub-dict has the same keys.
        Raises an Exception if these requirements are violated.
        """
        # to be implemented
        pass

    def row(self, row_label):
        """Returns a ParameterSet object containing the requested row."""
        return self[row_label]

    def column(self, column_label):
        """Returns a ParameterSet object containing the requested column."""
        col = {}
        for row_label, row in self.rows():
            col[row_label] = row[column_label]
        return ParameterSet(col)

    def columns(self):
        """Return a list of (column_label, column) pairs, as 2-tuples."""
        return [(column_label, self.column(column_label)) for column_label in self.column_labels()]

    def column_labels(self):
        """
        Return a list of column labels, taken from one of the rows.
        An empty table has no columns, so an empty list is returned.
        """
        row_labels = self.row_labels()
        if not row_labels:
            # no row to sample the labels from
            return []
        sample_row = self[row_labels[0]]
        return sample_row.keys()

    def transpose(self):
        """
        Return a new `ParameterTable` object with the same data as the current
        one but with rows and columns swapped.
        """
        new_table = ParameterTable({})
        for column_label, column in self.columns():
            new_table[column_label] = column
        return new_table

    def table_string(self):
        """
        Returns the table as a string, suitable for being used as the
        initialiser for a new `ParameterTable`.
        """
        # formatting could definitely be improved
        column_labels = self.column_labels()
        lines = [ "#\t " + "\t".join(column_labels) ]
        for row_label, row in self.rows():
            lines.append(row_label + "\t" + "\t".join(["%s" % row[col] for col in column_labels]))
        return "\n".join(lines)
708
Note: See TracBrowser for help on using the browser.