Changeset 859 for trunk

Show
Ignore:
Timestamp:
12/10/10 16:41:57 (2 years ago)
Author:
apdavison
Message:

More unit tests, and some minor changes in common.

Location:
trunk
Files:
6 modified

Legend:

Unmodified
Added
Removed
  • trunk/src/common.py

    r857 r859  
    471471         
    472472        if hasattr(self, "_get_array"): 
    473             values = self._get_array(parameter_name) 
     473            values = self._get_array(parameter_name).tolist() 
    474474        else: 
    475475            values = [getattr(cell, parameter_name) for cell in self]  # list or array? 
    476476         
    477477        if gather == True and num_processes() > 1: 
    478             all_values = { rank(): values } 
    479             all_index = { rank(): self.local_cells.tolist()} 
    480             all_values = recording.gather_dict(all_values) 
    481             all_index  = recording.gather_dict(all_index) 
     478            all_values  = { rank(): values } 
     479            all_indices = { rank(): self.local_cells.tolist()} 
     480            all_values  = recording.gather_dict(all_values) 
     481            all_indices = recording.gather_dict(all_indices) 
    482482            if rank() == 0: 
    483                 values = reduce(operator.add, all_values.values()) 
    484                 index  = reduce(operator.add, all_index.values()) 
    485             idx    = argsort(index) 
     483                values  = reduce(operator.add, all_values.values()) 
     484                indices = reduce(operator.add, all_indices.values()) 
     485            idx    = numpy.argsort(indices) 
    486486            values = numpy.array(values)[idx] 
    487487        return values 
     
    912912        if isinstance(id, IDMixin): 
    913913            if not self.first_id <= id <= self.last_id: 
    914                 raise IndexError("id should be in the range [%d,%d], actually %d" % (self.first_id, self.last_id, id)) 
     914                raise ValueError("id should be in the range [%d,%d], actually %d" % (self.first_id, self.last_id, id)) 
     915            return int(id - self.first_id)  # this assumes ids are consecutive 
    915916        else: 
    916917            id = numpy.array(id, IDMixin) 
    917918            if (self.first_id > id.min()) or (self.last_id < id.max()): 
    918                 raise IndexError("ids should be in the range [%d,%d], actually [%d, %d]" % (self.first_id, self.last_id, id.min(), id.max())) 
    919         return id - self.first_id  # this assumes ids are consecutive 
     919                raise ValueError("ids should be in the range [%d,%d], actually [%d, %d]" % (self.first_id, self.last_id, id.min(), id.max())) 
     920            return (id - self.first_id).astype(int)  # this assumes ids are consecutive 
    920921 
    921922    def id_to_local_index(self, id): 
     
    11611162        result[:,1:4] = self.positions.T  
    11621163        if rank() == 0: 
    1163             file.write(result, {'population' : self.label}) 
     1164            file.write(result, {'assembly' : self.label}) 
    11641165            file.close() 
    11651166 
     
    14131414        # it is arguable whether functions operating on the set of weights 
    14141415        # should be put here or in an external module. 
     1416        weights = self.getWeights(format='list', gather=True) 
     1417        if min is None: 
     1418            min = weights.min() 
     1419        if max is None: 
     1420            max = weights.max() 
    14151421        bins = numpy.linspace(min, max, nbins+1) 
    1416         return numpy.histogram(self.getWeights(format='list', gather=True), bins)  # returns n, bins 
     1422        return numpy.histogram(weights, bins, new=True)  # returns n, bins 
    14171423 
    14181424    def describe(self, template='projection_default.txt', engine='default'): 
  • trunk/test/unittests/test_assembly.py

    r853 r859  
    1  
     1from pyNN import common 
    22from pyNN.common import Assembly, BasePopulation 
    33from nose.tools import assert_equal, assert_raises 
     
    1818    _mask_local = numpy.arange(10)%2 == 1 
    1919    initialize = Mock() 
     20    positions = numpy.arange(3*size).reshape(3,size) 
    2021    def describe(self, template='abcd', engine=None): 
    2122        if template is None: 
     
    4546    a = Assembly(p1, p2, label="test") 
    4647    assert_equal(a.size, p1.size + p2.size) 
     48 
     49def test_positions_property(): 
     50    p1 = MockPopulation() 
     51    p2 = MockPopulation() 
     52    a = Assembly(p1, p2, label="test") 
     53    assert_arrays_equal(a.positions, numpy.concatenate((p1.positions, p2.positions), axis=1)) 
    4754 
    4855def test__len__(): 
     
    103110    a1 += a2 
    104111    assert_equal(a1.populations, [p1, p2, p2, p3]) 
    105      
     112 
     113def test_add_invalid_object(): 
     114    p1 = MockPopulation() 
     115    p2 = MockPopulation() 
     116    a = Assembly(p1, p2) 
     117    assert_raises(TypeError, a.__add__, 42) 
     118    assert_raises(TypeError, a.__iadd__, 42) 
     119 
    106120def test_initialize(): 
    107121    p1 = MockPopulation() 
     
    164178    assert_arrays_equal(a._mask_local, numpy.append(p1._mask_local, (p2._mask_local, p3._mask_local))) 
    165179    assert_arrays_equal(a.local_cells, a.all_cells[a._mask_local]) 
     180 
     181def test_save_positions(): 
     182    import os 
     183    orig_rank = common.rank 
     184    common.rank = lambda: 0 
     185    p1 = MockPopulation() 
     186    p2 = MockPopulation() 
     187    p1.all_cells = numpy.array([34, 45]) 
     188    p2.all_cells = numpy.array([56, 67]) 
     189    p1.positions = numpy.arange(0,6).reshape((2,3)).T 
     190    p2.positions = numpy.arange(6,12).reshape((2,3)).T 
     191    a = Assembly(p1, p2, label="test") 
     192    output_file = Mock() 
     193    a.save_positions(output_file) 
     194    assert_arrays_equal(output_file.write.call_args[0][0], 
     195                        numpy.array([[34, 0, 1, 2], [45, 3, 4, 5], [56, 6, 7, 8], [67, 9, 10, 11]])) 
     196    assert_equal(output_file.write.call_args[0][1], {'assembly': a.label}) 
     197    # arguably, the first column should contain indices, not ids. 
     198    common.rank = orig_rank 
  • trunk/test/unittests/test_basepopulation.py

    r858 r859  
    158158    MockPopulation.__iter__ = orig_iter 
    159159 
    160 #def test_get_with_gather(): 
    161 #    np_orig = common.num_processes 
    162 #    rank_orig = common.rank 
    163 #    gd_orig = common.recording.gather_dict 
    164 #    common.num_processes = lambda: 2 
    165 #    common.rank = 0 
    166 #    common.recording.gather_dict = Mock(return_value={0:  
    167 #     
    168 #    p = MockPopulation() 
    169 #    p._get_array = Mock(return_value=numpy.arange(10.0, 23.0, 1.0)) 
    170 #    p.get("tau_m") 
    171 #     
    172 #    common.num_processes = np_orig 
    173 #    common.rank = rank_orig 
    174 #    common.recording.gather_dict = gd_orig 
     160def test_get_with_gather(): 
     161    np_orig = common.num_processes 
     162    rank_orig = common.rank 
     163    gd_orig = common.recording.gather_dict 
     164    common.num_processes = lambda: 2 
     165    common.rank = lambda: 0 
     166    def mock_gather_dict(D): # really hacky 
     167        assert isinstance(D[0], list) 
     168        D[1] = [i-1 for i in D[0]] + [D[0][-1] + 1] 
     169        return D 
     170    common.recording.gather_dict = mock_gather_dict 
     171     
     172    p = MockPopulation() 
     173    p._get_array = Mock(return_value=numpy.arange(11.0, 23.0, 2.0)) 
     174    assert_arrays_equal(p.get("tau_m", gather=True), 
     175                        numpy.arange(10.0, 23.0)) 
     176     
     177    common.num_processes = np_orig 
     178    common.rank = rank_orig 
     179    common.recording.gather_dict = gd_orig 
    175180 
    176181def test_set_from_dict(): 
  • trunk/test/unittests/test_population.py

    r850 r859  
    66 
    77 
    8 class MockID(object): 
    9     def __init__(self, i, parent): 
    10         self.label = str(i) 
    11         self.parent = parent 
     8class MockID(int, common.IDMixin): 
     9    def __init__(self, n): 
     10        int.__init__(n) 
     11        common.IDMixin.__init__(self) 
    1212    def get_parameters(self): 
    1313        return {} 
     
    1616    recorder_class = Mock() 
    1717    initialize = Mock() 
    18     first_id = 999 
    19     last_id = 7777 
    2018     
    2119    def _create_cells(self, cellclass, cellparams, size): 
    22         self.all_cells = numpy.array([MockID(i, self) for i in range(size)], MockID) 
    23         self._mask_local = numpy.arange(size)%5==3 
     20        self.all_cells = numpy.array([MockID(i) for i in range(999, 999+size)], MockID) 
     21        self._mask_local = numpy.arange(size)%5==3 # every 5th cell, starting with the 4th, is on this node 
     22        self.first_id = self.all_cells[0] 
     23        self.last_id = self.all_cells[-1] 
    2424 
    2525class MockStandardCell(standardmodels.StandardCellType): 
     
    7777    assert_arrays_equal(p.cell, p.all_cells) 
    7878 
    79 #def test_id_to_index(): 
     79def test_id_to_index(): 
     80    p = MockPopulation(11, MockStandardCell) 
     81    assert isinstance(p[0], common.IDMixin) 
     82    assert_equal(p.id_to_index(p[0]), 0) 
     83    assert_equal(p.id_to_index(p[10]), 10) 
    8084 
    81 # test id_to_local_index 
     85def test_id_to_index_with_array(): 
     86    p = MockPopulation(11, MockStandardCell) 
     87    assert isinstance(p[0], common.IDMixin) 
     88    assert_arrays_equal(p.id_to_index(p.all_cells[3:9:2]), numpy.arange(3,9,2)) 
     89 
     90def test_id_to_index_with_populationview(): 
     91    p = MockPopulation(11, MockStandardCell) 
     92    assert isinstance(p[0], common.IDMixin) 
     93    view = p[3:7] 
     94    assert isinstance(view, common.PopulationView) 
     95    assert_arrays_equal(p.id_to_index(view), numpy.arange(3,7)) 
     96 
     97def test_id_to_index_with_invalid_id(): 
     98    p = MockPopulation(11, MockStandardCell) 
     99    assert isinstance(p[0], common.IDMixin) 
     100    assert_raises(ValueError, p.id_to_index, MockID(p.last_id+1)) 
     101    assert_raises(ValueError, p.id_to_index, MockID(p.first_id-1)) 
     102     
     103def test_id_to_index_with_invalid_ids(): 
     104    p = MockPopulation(11, MockStandardCell) 
     105    assert_raises(ValueError, p.id_to_index, [MockID(p.first_id-1)] + p.all_cells[0:3].tolist()) 
     106 
     107def test_id_to_local_index(): 
     108    orig_np = common.num_processes 
     109    common.num_processes = lambda: 5 
     110    p = MockPopulation(11, MockStandardCell) 
     111    # every 5th cell, starting with the 4th, is on this node. 
     112    assert_equal(p.id_to_local_index(p[3]), 0) 
     113    assert_equal(p.id_to_local_index(p[8]), 1) 
     114     
     115    common.num_processes = lambda: 1 
     116    # only one node 
     117    assert_equal(p.id_to_local_index(p[3]), 3) 
     118    assert_equal(p.id_to_local_index(p[8]), 8) 
     119    common.num_processes = orig_np 
     120 
     121def test_id_to_local_index_with_invalid_id(): 
     122    orig_np = common.num_processes 
     123    common.num_processes = lambda: 5 
     124    p = MockPopulation(11, MockStandardCell) 
     125    # every 5th cell, starting with the 4th, is on this node. 
     126    assert_raises(ValueError, p.id_to_local_index, p[0]) 
     127    common.num_processes = orig_np 
    82128 
    83129# test structure property 
     
    114160    assert p.positions[0,0] != 99.9 
    115161 
     162def test_position_generator(): 
     163    p = MockPopulation(11, MockStandardCell) 
     164    assert_arrays_equal(p.position_generator(0), p.positions[:,0]) 
     165    assert_arrays_equal(p.position_generator(10), p.positions[:,10]) 
     166    assert_arrays_equal(p.position_generator(-1), p.positions[:,10]) 
     167    assert_arrays_equal(p.position_generator(-11), p.positions[:,0]) 
     168    assert_raises(IndexError, p.position_generator, 11) 
     169    assert_raises(IndexError, p.position_generator, -12) 
     170 
    116171# test describe method 
    117172def test_describe(): 
  • trunk/test/unittests/test_populationview.py

    r850 r859  
    7474# test initial values property 
    7575 
    76 # test structure property 
     76def test_structure_property(): 
     77    p = MockPopulation(11, MockStandardCell) 
     78    mask = slice(3,9,2) 
     79    pv = common.PopulationView(parent=p, selector=mask) 
     80    assert_equal(pv.structure, p.structure) 
    7781 
    7882# test positions property 
  • trunk/test/unittests/test_projection.py

    r850 r859  
    44import numpy 
    55import os 
     6from pyNN.utility import assert_arrays_equal 
    67 
    78orig_rank = common.rank 
     
    190191    os.remove(filename) 
    191192 
     193def test_weight_histogram_with_args(): 
     194    p1 = MockPopulation() 
     195    p2 = MockPopulation() 
     196    prj = common.Projection(p1, p2, method=Mock()) 
     197    prj.getWeights = Mock(return_value=numpy.array(range(10)*42)) 
     198    n, bins = prj.weightHistogram(min=0.0, max=9.0, nbins=10) 
     199    assert_equal(n.size, 10) 
     200    assert_equal(bins.size, n.size+1) 
     201    assert_arrays_equal(n, 42*numpy.ones(10)) 
     202    assert_equal(n.sum(), 420) 
     203    assert_arrays_equal(bins, numpy.arange(0.0, 9.1, 0.9)) 
     204 
     205def test_weight_histogram_no_args(): 
     206    p1 = MockPopulation() 
     207    p2 = MockPopulation() 
     208    prj = common.Projection(p1, p2, method=Mock()) 
     209    prj.getWeights = Mock(return_value=numpy.array(range(10)*42)) 
     210    n, bins = prj.weightHistogram(nbins=10) 
     211    assert_equal(n.size, 10) 
     212    assert_equal(bins.size, n.size+1) 
     213    assert_arrays_equal(n, 42*numpy.ones(10)) 
     214    assert_equal(n.sum(), 420) 
     215    assert_arrays_equal(bins, numpy.arange(0.0, 9.1, 0.9)) 
     216 
    192217def test_describe(): 
    193218    p1 = MockPopulation()