root/branches/parameter_space/test/unittests/test_basepopulation.py @ 1156

Revision 1156, 14.9 KB (checked in by apdavison, 13 months ago)

Merged neo_output branch into parameter_space branch

Line 
1from pyNN import common, errors, random, standardmodels, recording
2from pyNN.common import populations
3from pyNN.parameters import Sequence, ParameterSpace
4from nose.tools import assert_equal, assert_raises
5import numpy
6from mock import Mock, patch
7from pyNN.utility import assert_arrays_equal
8from pyNN import core
9from lazyarray import VectorizedIterable
10   
# Handle on the real built-in ``open`` -- presumably kept so a test can patch
# ``open`` and still reach the original; not used in the visible part of this file.
builtin_open = open
# Label -> global cell index, consumed by MockPopulation.id_to_index() and
# MockPopulation.id_to_local_index().
id_map = dict(zip(
    ['larry', 'curly', 'moe', 'joe', 'william', 'jack', 'averell'],
    range(7),
))
13
14
class MockSimulator(object):
    """Minimal simulator stand-in that only exposes a ``state`` object."""

    class MockState(object):
        """Fake simulator state: MPI process 1 out of 3."""
        num_processes = 3
        mpi_rank = 1

    state = MockState()
20
class MockStandardCell(standardmodels.StandardCellType):
    """Minimal cell type with four parameters and no name translation.

    Every translation entry is ``None`` and ``translate`` is the identity,
    so parameter dictionaries pass through unchanged.
    """
    recordable = ['v', 'spikes']
    default_parameters = {
        'tau_m': 999.9,
        'i_offset': 321.0,
        'spike_times': Sequence([0, 1, 2]),
        'foo': 33.3,
    }
    translations = dict.fromkeys(['tau_m', 'i_offset', 'spike_times', 'foo'])

    @classmethod
    def translate(cls, parameters):
        """Return *parameters* as-is."""
        return parameters

    @classmethod
    def computed_parameters(cls):
        """Return an empty list: no parameter is derived from the others."""
        return []
31
class MockPopulation(populations.BasePopulation):
    """Fixed 13-cell population used by the tests in this module.

    Cell IDs are 100..112. ``_mask_local`` marks the odd-indexed cells as
    local, so six cells are local. ``positions`` holds the integers 0..38 as
    a (3, 13) array.

    NOTE: every attribute below is a class attribute, shared by all
    instances. A test that mutates one in place should first give the
    instance its own copy.
    """
    _simulator = MockSimulator
    size = 13
    all_cells = numpy.arange(100, 113)
    _mask_local = numpy.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], bool)
    local_cells = all_cells[_mask_local]
    positions = numpy.arange(39).reshape((13, 3)).T
    label = "mock_population"
    celltype = MockStandardCell({})
    initial_values = {"foo": core.LazyArray(numpy.array((98, 100, 102)), shape=(3,))}
    _assembly_class = populations.Assembly

    def _get_view(self, selector, label=None):
        """Return a PopulationView of this population for *selector*."""
        return populations.PopulationView(self, selector, label)

    def id_to_index(self, id):
        """Return the global index of *id*, looked up by ``id.label`` in ``id_map``.

        Raises Exception("Invalid ID") for unknown labels.
        """
        try:
            return id_map[id.label]
        except KeyError:
            raise Exception("Invalid ID")

    def id_to_local_index(self, id):
        """Return the position of *id* among this node's local cells.

        Only odd global indices are local (matching ``_mask_local``). The
        n-th local cell has global index 2n+1, so the local index is
        ``global_index // 2``.

        Raises Exception("Invalid ID") for unknown labels and
        Exception("ID not on this node") for even (non-local) indices.
        """
        try:
            global_index = id_map[id.label]
        except KeyError:
            raise Exception("Invalid ID")
        if global_index % 2 != 1:
            raise Exception("ID not on this node")
        # Floor division: plain "/" would return a float on Python 3.
        return global_index // 2
62
class MockID(object):
    """Bare-bones cell ID carrying only a ``label`` and its ``parent`` population."""

    def __init__(self, label, parent):
        self.label, self.parent = label, parent
67
def test__getitem__int():
    """Integer indexing returns cell IDs, supports negatives, and bounds-checks."""
    pop = MockPopulation()
    for index, expected_id in ((0, 100), (12, 112)):
        assert_equal(pop[index], expected_id)
    assert_raises(IndexError, pop.__getitem__, 13)
    assert_equal(pop[-1], 112)
74   
def test__getitem__slice():
    """Slicing a population builds a PopulationView from that slice."""
    # patch.object restores the real PopulationView even when the assertion
    # fails, so a failure here cannot leak a Mock into later tests.
    with patch.object(populations, "PopulationView") as mock_view:
        p = MockPopulation()
        p[3:9]
        mock_view.assert_called_with(p, slice(3, 9, None), None)
82
def test__getitem__list():
    """Indexing with a list of indices builds a PopulationView from that list."""
    # patch.object restores the real PopulationView even if the assertion fails.
    with patch.object(populations, "PopulationView") as mock_view:
        p = MockPopulation()
        p[range(3, 9)]
        mock_view.assert_called_with(p, range(3, 9), None)
90
def test__getitem__tuple():
    """Indexing with a tuple builds a PopulationView from the equivalent list."""
    # patch.object restores the real PopulationView even if the assertion fails.
    with patch.object(populations, "PopulationView") as mock_view:
        p = MockPopulation()
        p[(3, 5, 7)]
        mock_view.assert_called_with(p, [3, 5, 7], None)
98
def test__getitem__invalid():
    """Indexing with an unsupported type (a string) raises TypeError."""
    assert_raises(TypeError, MockPopulation().__getitem__, "foo")
102
def test_len():
    """len() of a population equals its declared size."""
    assert_equal(len(MockPopulation()), MockPopulation.size)
106
def test_iter():
    """Iterating a population yields an iterator over its 6 local cells."""
    p = MockPopulation()
    itr = p.__iter__()
    # Python 2 iterators expose ``next``; Python 3 iterators expose ``__next__``.
    assert hasattr(itr, "next") or hasattr(itr, "__next__")
    assert_equal(len(list(itr)), 6)
112
def test_is_local():
    """is_local() distinguishes local/non-local IDs and rejects foreign IDs."""
    owner, other = MockPopulation(), MockPopulation()
    local_id = MockID("curly", parent=owner)
    remote_id = MockID("larry", parent=owner)
    assert owner.is_local(local_id)
    assert not owner.is_local(remote_id)
    # An ID whose parent is a different population trips an assertion.
    assert_raises(AssertionError, other.is_local, local_id)
121   
def test_all():
    """all() returns an iterator over every one of the 13 cells."""
    p = MockPopulation()
    itr = p.all()
    # Python 2 iterators expose ``next``; Python 3 iterators expose ``__next__``.
    assert hasattr(itr, "next") or hasattr(itr, "__next__")
    assert_equal(len(list(itr)), 13)
127
def test_add():
    """Adding two populations yields an Assembly holding both, in order."""
    first, second = MockPopulation(), MockPopulation()
    combined = first + second
    assert isinstance(combined, populations.Assembly)
    assert_equal(combined.populations, [first, second])
134   
def test_get_cell_position():
    """_get_cell_position() returns the (x, y, z) column for the given cell."""
    p = MockPopulation()
    for name, expected in (("larry", [0, 1, 2]), ("moe", [6, 7, 8])):
        cell_id = MockID(name, parent=p)
        assert_arrays_equal(p._get_cell_position(cell_id), numpy.array(expected))
141   
def test_set_cell_position():
    """_set_cell_position() overwrites only the target cell's coordinates."""
    p = MockPopulation()
    # ``positions`` is a class attribute shared by every MockPopulation. If
    # _set_cell_position writes into the array in place, the change would
    # leak into other tests, so give this instance a private copy.
    p.positions = MockPopulation.positions.copy()
    cell_id = MockID("larry", parent=p)
    p._set_cell_position(cell_id, numpy.array([100, 101, 102]))
    assert_equal(p.positions[0, 0], 100)
    assert_equal(p.positions[0, 1], 3)
148
def test_get_cell_initial_value():
    """_get_cell_initial_value() returns the stored "foo" value for a local cell."""
    p = MockPopulation()
    value = p._get_cell_initial_value(MockID("curly", parent=p), "foo")
    assert_equal(value, 98)
153
def test_set_cell_initial_value():
    """A value written by _set_cell_initial_value() is read back unchanged."""
    p = MockPopulation()
    # The default ``initial_values`` dict is a class attribute shared by every
    # MockPopulation. Give this instance a private copy so the write below
    # cannot leak into other tests.
    p.initial_values = {"foo": core.LazyArray(numpy.array((98, 100, 102)), shape=(3,))}
    cell_id = MockID("curly", parent=p)
    p._set_cell_initial_value(cell_id, "foo", -1)
    assert_equal(p._get_cell_initial_value(cell_id, "foo"), -1)
159
def test_nearest():
    """nearest() returns the cell closest to the query point."""
    p = MockPopulation()
    p.positions = numpy.arange(39).reshape((13, 3)).T
    cases = [
        ((0.0, 1.0, 2.0), 0),
        ((3.0, 4.0, 5.0), 1),
        ((36.0, 37.0, 38.0), 12),
        ((1.49, 2.49, 3.49), 0),   # just below the midpoint between cells 0 and 1
        ((1.51, 2.51, 3.51), 1),   # just above it
    ]
    for point, expected_index in cases:
        assert_equal(p.nearest(point), p[expected_index])
168
def test_sample():
    """sample(n) selects the first n entries of rng.permutation()."""
    rng = Mock()
    rng.permutation = Mock(return_value=numpy.array([7, 4, 8, 12, 0, 3, 9, 1, 2, 11, 5, 10, 6]))
    # patch.object restores the real PopulationView even if the assertion fails.
    with patch.object(populations, "PopulationView") as mock_view:
        p = MockPopulation()
        p.sample(5, rng=rng)
        assert_arrays_equal(mock_view.call_args[0][1], numpy.array([7, 4, 8, 12, 0]))
178
def test_get_should_call_get_parameters():
    """get() delegates to _get_parameters() with the requested name."""
    p = MockPopulation()
    fake_parameters = {'tau_m': Mock()}
    p._get_parameters = Mock(return_value=fake_parameters)
    p.get("tau_m")
    p._get_parameters.assert_called_with("tau_m")
184
def test_get_with_gather():
    """get(gather=True) on rank 0 of 2 merges the values gathered from all nodes."""
    state = MockPopulation._simulator.state
    saved = (state.num_processes, state.mpi_rank, recording.gather_dict)

    def mock_gather_dict(D):  # really hacky
        # Fake the other node's contribution (D presumably keyed by rank):
        # entry 1 holds each of entry 0's values minus one, plus one value
        # past the end of entry 0.
        assert isinstance(D[0], (list, numpy.ndarray))
        D[1] = [i - 1 for i in D[0]] + [D[0][-1] + 1]
        return D

    try:
        state.num_processes = 2
        state.mpi_rank = 0
        recording.gather_dict = mock_gather_dict

        p = MockPopulation()
        ps = Mock()
        ps.evaluate = Mock(return_value=numpy.arange(11.0, 23.0, 2.0))
        p._get_parameters = Mock(return_value={'tau_m': ps})
        assert_arrays_equal(p.get("tau_m", gather=True),
                            numpy.arange(10.0, 23.0))
    finally:
        # Restore the shared simulator state and the real gather_dict even
        # if the assertion above fails.
        state.num_processes, state.mpi_rank, recording.gather_dict = saved
206
def test_set():
    """set() wraps keyword values in a ParameterSpace and forwards it."""
    p = MockPopulation()
    p._set_parameters = Mock()
    p.set(tau_m=43.21)
    expected = ParameterSpace({'tau_m': 43.21}, p.celltype.get_schema(), size=p.size)
    p._set_parameters.assert_called_with(expected)
213
def test_set_invalid_type():
    """set() rejects values of an unsupported type (a dict or a string)."""
    p = MockPopulation()
    for bad_value in ({}, 'bar'):
        assert_raises(errors.InvalidParameterValueError, p.set, foo=bad_value)
218
def test_set_with_list():
    """set() accepts a list-like value for spike_times and forwards a ParameterSpace."""
    p = MockPopulation()
    p._set_parameters = Mock()
    p.set(spike_times=range(10))
    expected = ParameterSpace({'spike_times': range(10)},
                              p.celltype.get_schema(), size=p.size)
    p._set_parameters.assert_called_with(expected)
225   
def test_tset_with_numeric_values():
    """tset() forwards a per-cell numeric array (checked on the local cells)."""
    p = MockPopulation()
    p._set_parameters = Mock()
    values = numpy.linspace(10.0, 20.0, num=p.size)
    p.tset("tau_m", values)
    forwarded = p._set_parameters.call_args[0][0]['tau_m']
    local = p._mask_local
    assert_arrays_equal(forwarded[local], values[local])
233
def test_tset_with_array_values():
    """tset() accepts one Sequence per cell for an array-valued parameter."""
    p = MockPopulation()
    p._set_parameters = Mock()
    per_cell = [Sequence(numpy.linspace(i, 100.0 + i, 10)) for i in range(p.size)]
    p.tset("spike_times", per_cell)
    forwarded = p._set_parameters.call_args[0][0]['spike_times']
    local = p._mask_local
    assert_equal(forwarded.shape[0], len(per_cell))
    assert_arrays_equal(forwarded[local], numpy.array(per_cell)[local])
243   
def test_tset_invalid_dimensions_2D():
    """Population.tset(): a value array whose size does not match the Population must raise InvalidDimensionsError."""
    p = MockPopulation()
    two_by_three = numpy.array([[0.1, 0.2, 0.3],
                                [0.4, 0.5, 0.6]])
    assert_raises(errors.InvalidDimensionsError, p.tset, 'i_offset', two_by_three)
249
def test_tset_invalid_dimensions_1D():
    """tset() rejects a 1-D array one element longer than the population."""
    p = MockPopulation()
    too_long = numpy.linspace(10.0, 20.0, num=p.size + 1)
    assert_raises(errors.InvalidDimensionsError, p.tset, "tau_m", too_long)
254
255
class MockRandDistr(VectorizedIterable):
    """Deterministic stand-in for a random distribution."""

    def next(self, n):
        """Return the integers 0, 1, ..., n-1 as a numpy array."""
        draws = numpy.arange(0, n)
        return draws
259
def test_rset():
    """Population.rset()"""
    # TODO (from original author): this test should assume MPI and make use
    # of mask_local.
    p = MockPopulation()
    p._set_parameters = Mock()
    p.rset("foo", MockRandDistr())
    forwarded = p._set_parameters.call_args[0][0]['foo']
    assert_arrays_equal(forwarded.evaluate(), numpy.arange(p.size))
270
def test_rset_with_native_rng():
    """rset() with a NativeRNG-backed distribution delegates to _native_rset()."""
    p = MockPopulation()
    p._native_rset = Mock()
    distribution = Mock()
    distribution.rng = random.NativeRNG()
    p.rset('tau_m', distribution)
    p._native_rset.assert_called_with('tau_m', distribution)
278
def test_initialize():
    """initialize() stores a scalar initial value and forwards it to _set_initial_value_array()."""
    p = MockPopulation()
    p.initial_values = {}
    p._set_initial_value_array = Mock()
    initial_v = -65.0
    p.initialize('v', initial_v)
    assert_equal(p.initial_values['v'].evaluate(simplify=True), initial_v)
    p._set_initial_value_array.assert_called_with('v', initial_v)
286
def test_initialize_random_distribution():
    """initialize() with a RandomDistribution stores one draw per local cell."""
    p = MockPopulation()
    p.initial_values = {}
    p._set_initial_value_array = Mock()

    class MockRandomDistribution(random.RandomDistribution):
        def next(self, n, mask_local):
            # Constant "random" draws, filtered down to the local cells.
            return 42 * numpy.ones(n)[mask_local]

    p.initialize('v', MockRandomDistribution())
    expected = 42 * numpy.ones(p.local_size)
    assert_arrays_equal(p.initial_values['v'].evaluate(simplify=True), expected)
    # TODO: also verify the _set_initial_value_array('v', ...) call; that
    # check was commented out in the original version of this test.
297
def test_can_record():
    """can_record() accepts 'v' (listed in recordable) and rejects 'foo'."""
    p = MockPopulation()
    p.celltype = MockStandardCell({})
    assert p.can_record('v')
    assert not p.can_record('foo')
303   
def test_record_with_single_variable():
    """record('v') passes the variable and every cell ID to the recorder."""
    p = MockPopulation()
    p.recorder = Mock()
    p.record('v')
    method_name, call_args, _ = p.recorder.method_calls[0]
    recorded, ids = call_args
    assert_equal(method_name, 'record')
    assert_equal(recorded, 'v')
    assert_arrays_equal(ids, p.all_cells)
313
def test_record_with_multiple_variables():
    """record([...]) passes the variable list and every cell ID to the recorder."""
    wanted = ['v', 'gsyn_exc', 'spikes']
    p = MockPopulation()
    p.recorder = Mock()
    p.record(wanted)
    method_name, call_args, _ = p.recorder.method_calls[0]
    recorded, ids = call_args
    assert_equal(method_name, 'record')
    assert_equal(recorded, ['v', 'gsyn_exc', 'spikes'])
    assert_arrays_equal(ids, p.all_cells)
323   
def test_record_v():
    """record_v(x) is shorthand for record('v', x)."""
    p = MockPopulation()
    p.record = Mock()
    argument = "arg1"
    p.record_v(argument)
    p.record.assert_called_with('v', argument)
329
def test_record_gsyn():
    """record_gsyn(x) is shorthand for record(['gsyn_exc', 'gsyn_inh'], x)."""
    p = MockPopulation()
    p.record = Mock()
    argument = "arg1"
    p.record_gsyn(argument)
    p.record.assert_called_with(['gsyn_exc', 'gsyn_inh'], argument)
335
def test_printSpikes():
    """printSpikes() maps to recorder.write('spikes', file, gather, record_filter)."""
    p = MockPopulation()
    p.recorder = Mock()
    p.record_filter = "filter"
    p.printSpikes("file", "gather", "compatible_output")
    method_name, call_args, _ = p.recorder.method_calls[0]
    assert_equal(method_name, 'write')
    assert_equal(call_args, ("spikes", "file", "gather", "filter"))
344   
def test_getSpikes():
    """getSpikes() maps to recorder.get('spikes', gather, record_filter, False)."""
    p = MockPopulation()
    p.recorder = Mock()
    p.record_filter = "filter"
    p.getSpikes("gather", "compatible_output")
    method_name, call_args, _ = p.recorder.method_calls[0]
    assert_equal(method_name, 'get')
    assert_equal(call_args, ("spikes", "gather", "filter", False))
353
def test_print_v():
    """print_v() maps to recorder.write('v', file, gather, record_filter)."""
    p = MockPopulation()
    p.recorder = Mock()
    p.record_filter = "filter"
    p.print_v("file", "gather", "compatible_output")
    method_name, call_args, _ = p.recorder.method_calls[0]
    assert_equal(method_name, 'write')
    assert_equal(call_args, ("v", "file", "gather", "filter"))
362   
def test_get_v():
    """get_v() maps to recorder.get('v', gather, record_filter, False)."""
    p = MockPopulation()
    p.recorder = Mock()
    p.record_filter = "filter"
    p.get_v("gather", "compatible_output")
    method_name, call_args, _ = p.recorder.method_calls[0]
    assert_equal(method_name, 'get')
    assert_equal(call_args, ("v", "gather", "filter", False))
371   
def test_print_gsyn():
    """print_gsyn() writes both conductance variables via recorder.write()."""
    p = MockPopulation()
    p.recorder = Mock()
    p.record_filter = "filter"
    p.print_gsyn("file", "gather", "compatible_output")
    method_name, call_args, _ = p.recorder.method_calls[0]
    assert_equal(method_name, 'write')
    assert_equal(call_args, (["gsyn_exc", "gsyn_inh"], "file", "gather", "filter"))
380   
def test_get_gsyn():
    """get_gsyn() fetches both conductance variables via recorder.get()."""
    p = MockPopulation()
    p.recorder = Mock()
    p.record_filter = "filter"
    p.get_gsyn("gather", "compatible_output")
    method_name, call_args, _ = p.recorder.method_calls[0]
    assert_equal(method_name, 'get')
    assert_equal(call_args, (["gsyn_exc", "gsyn_inh"], "gather", "filter", False))
389   
def test_get_spike_counts():
    """get_spike_counts(gather) maps to recorder.count('spikes', gather, None)."""
    p = MockPopulation()
    p.recorder = Mock()
    p.get_spike_counts("gather")
    method_name, call_args, _ = p.recorder.method_calls[0]
    assert_equal(method_name, 'count')
    assert_equal(call_args, ("spikes", "gather", None))
397   
def test_mean_spike_count():
    """On rank 0, mean_spike_count() averages the recorder's per-cell counts."""
    state = MockPopulation._simulator.state
    orig_rank = state.mpi_rank
    try:
        state.mpi_rank = 0
        p = MockPopulation()
        p.recorder = Mock()
        p.recorder.count = Mock(return_value={0: 2, 1: 5})
        assert_equal(p.mean_spike_count(), 3.5)
    finally:
        # Restore the shared simulator state even if the assertion fails.
        state.mpi_rank = orig_rank
406
def test_mean_spike_count_on_slave_node():
    """On a non-zero rank, mean_spike_count() returns NaN."""
    state = MockPopulation._simulator.state
    orig_rank = state.mpi_rank
    try:
        state.mpi_rank = 1
        p = MockPopulation()
        p.recorder = Mock()
        p.recorder.count = Mock(return_value={0: 2, 1: 5})
        # ``result is numpy.NaN`` only holds if the exact same NaN object is
        # returned (and numpy.NaN no longer exists in NumPy 2); test the value.
        assert numpy.isnan(p.mean_spike_count())
    finally:
        # Restore the shared simulator state even if the assertion fails.
        state.mpi_rank = orig_rank
415   
def test_inject():
    """inject(source) calls source.inject_into(population)."""
    p = MockPopulation()
    current_source = Mock()
    p.inject(current_source)
    method_name, call_args, _ = current_source.method_calls[0]
    assert_equal(method_name, "inject_into")
    assert_equal(call_args, (p,))
423
def test_inject_into_invalid_celltype():
    """inject() raises TypeError when the cell type is not injectable."""
    p = MockPopulation()
    # The default ``celltype`` is a class attribute shared by every
    # MockPopulation. Setting ``injectable`` on it would leak into later
    # tests, so use a private instance.
    p.celltype = MockStandardCell({})
    p.celltype.injectable = False
    assert_raises(TypeError, p.inject, Mock())
428
def test_save_positions():
    """save_positions() writes an (id, x, y, z) row per cell plus population metadata."""
    state = MockPopulation._simulator.state
    orig_rank = state.mpi_rank
    try:
        state.mpi_rank = 0
        p = MockPopulation()
        p.all_cells = numpy.array([34, 45, 56, 67])
        p.positions = numpy.arange(12).reshape((4, 3)).T
        output_file = Mock()
        p.save_positions(output_file)
        write_args = output_file.write.call_args[0]
        assert_arrays_equal(write_args[0],
                            numpy.array([[34, 0, 1, 2], [45, 3, 4, 5], [56, 6, 7, 8], [67, 9, 10, 11]]))
        assert_equal(write_args[1], {'population': p.label})
        # arguably, the first column should contain indices, not ids.
    finally:
        # Restore the shared simulator state even if an assertion fails.
        state.mpi_rank = orig_rank
Note: See TracBrowser for help on using the browser.