Changeset 283

Show
Ignore:
Timestamp:
11/04/08 20:31:18 (2 months ago)
Author:
bruederle
Message:

Formatted, documented and marginally extended the functions and classes included in module 'plotting' so far. Added working unit tests for everything.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/src/plotting.py

    r250 r283  
    1  
    2 import numpy, pylab, sys 
    3 from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 
    4 from matplotlib.figure import Figure 
    5 from matplotlib.lines import Line2D 
    61""" 
    7 Plotting helpers 
     2plotting.py 
     3 
     4Routines that make nice plotting with Matplotlib easier. 
    85""" 
    96 
    10 def pylab_params(fig_width_pt = 246.0, 
    11                 ratio = (numpy.sqrt(5)-1.0)/2.0,# Aesthetic golden mean ratio by default 
    12                 text_fontsize = 10 , tick_labelsize = 8): 
    13     """ 
    14     This functions calls some parameters to properly print figures for your papers. 
    15     See http://www.scipy.org/Cookbook/Matplotlib/UsingTex 
    16  
    17     fig_width_pt   # Get this from LaTeX using \showthe\columnwidth 
    18      
    19  
    20     ratio : ratio between the height and the width of the figure 
    21  
    22     """ 
    23  
    24     # TODO: include a conversion to cm cm=1/3 # inches per cm 
    25  
     7 
     8import numpy, sys 
     9numpy_version = numpy.__version__.split(".")[0:2] 
     10numpy_version = float(".".join(numpy_version)) 
     11 
     12 
     13# Check availability of pylab 
     14try : 
     15    import pylab 
     16    from matplotlib.figure import Figure 
     17    from matplotlib.lines import Line2D 
     18    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 
     19except ImportError: 
     20    MATPLOTLIB_ERROR = \ 
     21" ----------------- MATPLOTLIB Warning : ---------------------\n \ 
     22Matplolib NOT detected. The module NeuroTools.plotting depends on this package.\n \ 
     23Please install the Matplotlib package\n \ 
     24--> http://matplotlib.sourceforge.net/" 
     25    raise Exception(MATPLOTLIB_ERROR) 
     26 
     27 
     28 
     29######################################################## 
     30# UNIVERSAL FUNCTIONS AND CLASSES FOR NORMAL PYLAB USE # 
     31######################################################## 
     32 
     33 
     34 
     35def pylab_params(fig_width_pt=246.0, 
     36                ratio=(numpy.sqrt(5)-1.0)/2.0,# Aesthetic golden mean ratio by default 
     37                text_fontsize=10, tick_labelsize=8, useTex=False): 
     38    """ 
     39    Returns a dictionary with a set of parameters that help to nicely format figures. 
     40    The return object can be used to update the pylab run command parameters dictionary 'pylab.rcParams'. 
     41 
     42    Inputs: 
     43        fig_width_pt   - Figure width in points. If you want to use your figure inside LaTeX, 
     44                         get this value from LaTeX using '\showthe\columnwidth'. 
     45        ratio          - Ratio between the height and the width of the figure. 
     46        text_fontsize  - Size of axes and in-pic text fonts. 
     47        tick_labelsize - Size of tick label font. 
     48        useTex         - Enables or disables the use of LaTeX for all labels and texts 
     49                         (for details on how to do that, see http://www.scipy.org/Cookbook/Matplotlib/UsingTex). 
     50 
     51    """ 
    2652    inches_per_pt = 1.0/72.27               # Convert pt to inch 
    2753    fig_width = fig_width_pt*inches_per_pt  # width in inches 
    28     fig_height = fig_width*ratio      # height in inches 
     54    fig_height = fig_width*ratio            # height in inches 
    2955    fig_size =  [fig_width,fig_height] 
    3056 
    3157    params = { 
    32             'axes.labelsize': text_fontsize, 
    33             'text.fontsize': text_fontsize, 
    34             'xtick.labelsize': tick_labelsize, 
    35             'ytick.labelsize': tick_labelsize, 
    36             'text.usetex':False, ##True, ## problem with svg output resolved in latest matplotlib 
    37             'figure.figsize': fig_size} 
     58            'axes.labelsize'  : text_fontsize, 
     59            'text.fontsize'   : text_fontsize, 
     60            'xtick.labelsize' : tick_labelsize, 
     61            'ytick.labelsize' : tick_labelsize, 
     62            'text.usetex'     : useTex, 
     63            'figure.figsize'  : fig_size} 
     64             
    3865    return params 
    3966 
    4067 
    41 def set_frame(ax,boollist,linewidth=2): 
    42     assert len(boollist) == 4 
    43     bottom = Line2D([0, 1], [0, 0], transform=ax.transAxes, linewidth=linewidth, color='k') 
    44     left   = Line2D([0, 0], [0, 1], transform=ax.transAxes, linewidth=linewidth, color='k') 
    45     top    = Line2D([0, 1], [1, 1], transform=ax.transAxes, linewidth=linewidth, color='k') 
    46     right  = Line2D([1, 0], [1, 1], transform=ax.transAxes, linewidth=linewidth, color='k') 
    47     # anti-aliased? 
    48     if boollist != [True,True,True,True]: 
    49         ax.set_frame_on(False) 
    50         for side,draw in zip([left,bottom,right,top],boollist): 
    51             if draw: 
    52                 ax.add_line(side) 
    53  
    54  
    55  
    56 ############################################################# 
    57 ## Utility function for the plots. Common to all the objects 
    58 ## They are called when we try to do plots 
    59 ############################################################# 
     68 
     69def set_pylab_params(fig_width_pt=246.0, 
     70                    ratio=(numpy.sqrt(5)-1.0)/2.0,# Aesthetic golden mean ratio by default 
     71                    text_fontsize=10, tick_labelsize=8, useTex=False): 
     72    """ 
     73    Updates a set of parameters within the the pylab run command parameters dictionary 'pylab.rcParams'  
     74    in order to achieve nicely formatted figures. 
     75 
     76    Inputs: 
     77        fig_width_pt   - Figure width in points. If you want to use your figure inside LaTeX, 
     78                         get this value from LaTeX using '\showthe\columnwidth'. 
     79        ratio          - Ratio between the height and the width of the figure. 
     80        text_fontsize  - Size of axes and in-pic text fonts. 
     81        tick_labelsize - Size of tick label font. 
     82        useTex         - Enables or disables the use of LaTeX for all labels and texts 
     83                         (for details on how to do that, see http://www.scipy.org/Cookbook/Matplotlib/UsingTex). 
     84 
     85    """ 
     86    pylab.rcParams.update(pylab_params(fig_width_pt=fig_width_pt, ratio=ratio, text_fontsize=text_fontsize, \ 
     87        tick_labelsize=tick_labelsize, useTex=useTex)) 
     88 
     89 
    6090 
    6191def get_display(display): 
    6292    """ 
    63     Return a pylab object with a plot() function to draw the plots. 
     93    Returns a pylab object with a plot() function to draw the plots. 
    6494     
    6595    Inputs: 
     
    74104    else: 
    75105        return display 
    76      
     106 
     107 
     108 
    77109def set_labels(subplot, xlabel, ylabel): 
    78110    """ 
    79     Function to put some labels on a plot 
     111    Defines the axis labels in a plot. 
    80112     
    81113    Inputs: 
     
    87119        subplot.xlabel(xlabel) 
    88120        subplot.ylabel(ylabel) 
    89     else
     121    elif hasattr(subplot, 'set_xlabel')
    90122        subplot.set_xlabel(xlabel) 
    91123        subplot.set_ylabel(ylabel) 
     124    else:  
     125        raise Exception('ERROR: The plot passed to function NeuroTools.plotting.set_label(...) does not provide labelling functions.') 
     126 
    92127 
    93128 
    94129def set_axis_limits(subplot, xmin, xmax, ymin, ymax): 
    95130    """ 
    96     Function to set the axis on a plot 
     131    Defines the axis limits in a plot. 
    97132     
    98133    Inputs: 
    99         subplot - the targeted plot 
     134        subplot     - the targeted plot 
    100135        xmin, xmax  - the limits of the x axis 
    101136        ymin, ymax  - the limits of the y axis 
     
    104139        subplot.xlim(xmin, xmax) 
    105140        subplot.ylim(ymin, ymax) 
    106     else
     141    elif hasattr(subplot, 'set_xlim')
    107142        subplot.set_xlim(xmin, xmax) 
    108143        subplot.set_ylim(ymin, ymax) 
    109  
    110  
    111  
    112  
    113  
    114  
     144    else:  
     145        raise Exception('ERROR: The plot passed to function NeuroTools.plotting.set_axis_limits(...) does not provide limit defining functions.') 
     146 
     147 
     148 
     149#################################################################### 
     150# SPECIAL PLOTTING FUNCTIONS AND CLASSES FOR SPECIFIC REQUIREMENTS # 
     151#################################################################### 
    115152 
    116153 
     
    121158    the same x-range. 
    122159    """ 
    123      
    124160    def __init__(self, nrows, ncolumns, title="", xlabel=None, ylabel=None, 
    125161                 scaling=('linear','linear')): 
     
    149185            for i in range(nrows): 
    150186                ax = self.fig.add_axes([leftlist[j],bottomlist[i],panelwidth,panelheight]) 
    151                 set_frame(ax,[True,True,False,False]) 
     187                self.set_frame(ax,[True,True,False,False]) 
    152188                ax.xaxis.tick_bottom() 
    153189                ax.yaxis.tick_left() 
     
    170206        else: 
    171207            raise Exception("Invalid value for scaling parameter") 
    172      
     208 
    173209    def finalise(self): 
    174210        """Adjustments to be made after all panels have been plotted.""" 
     
    179215            ax.xaxis.set_ticklabels([]) 
    180216 
    181     def save(self,filename): 
    182         """Save/print the figure to file.""" 
     217    def save(self, filename): 
     218        """Saves/prints the figure to file. 
     219         
     220        Inputs: 
     221            filename - string specifying the filename where to save the data 
     222        """ 
    183223        self.finalise() 
    184224        self.canvas.print_figure(filename) 
    185      
     225 
    186226    def next_panel(self): 
     227        """Changes to next panel within figure.""" 
    187228        ax = self.axes[self._curr_panel] 
    188229        self._curr_panel += 1 
     
    191232        ax.plot1 = getattr(ax, self.plot_function) 
    192233        return ax 
    193          
    194     def panel(self,i): 
    195         """Return panel i.""" 
     234 
     235    def panel(self, i): 
     236        """Returns panel i.""" 
    196237        ax = self.axes[i] 
    197238        ax.plot1 = getattr(ax, self.plot_function) 
    198239        return ax 
    199      
     240 
     241    def set_frame(self, ax, boollist, linewidth=2): 
     242        """ 
     243        Defines frames for the chosen axis. 
     244 
     245        Inputs: 
     246            as        - the targeted axis 
     247            boollist  - a list  
     248            linewidth - the limits of the y axis 
     249        """ 
     250        assert type(boollist) in [list, numpy.ndarray] 
     251        assert len(boollist) == 4 
     252        if boollist != [True,True,True,True]: 
     253            bottom = Line2D([0, 1], [0, 0], transform=ax.transAxes, linewidth=linewidth, color='k') 
     254            left   = Line2D([0, 0], [0, 1], transform=ax.transAxes, linewidth=linewidth, color='k') 
     255            top    = Line2D([0, 1], [1, 1], transform=ax.transAxes, linewidth=linewidth, color='k') 
     256            right  = Line2D([1, 0], [1, 1], transform=ax.transAxes, linewidth=linewidth, color='k') 
     257            ax.set_frame_on(False) 
     258            for side,draw in zip([left,bottom,right,top],boollist): 
     259                if draw: 
     260                    ax.add_line(side) 
  • trunk/test/test_plotting.py

    r245 r283  
    22Unit tests for the NeuroTools.plotting module 
    33""" 
     4 
     5import unittest 
     6from NeuroTools import plotting 
     7import pylab 
     8import os 
     9 
     10class PylabParamsTest(unittest.TestCase): 
     11 
     12    def runTest(self): 
     13         
     14        # define arbitrary values 
     15        fig_width_pt =  123.4 
     16        ratio = 0.1234 
     17        text_fontsize = 10  
     18        tick_labelsize = 8 
     19        useTex = True 
     20 
     21        inches_per_pt = 1.0/72.27               # Convert pt to inch 
     22        fig_width = fig_width_pt*inches_per_pt  # width in inches 
     23        fig_height = fig_width*ratio            # height in inches 
     24 
     25        testDict = { 
     26            'axes.labelsize'  : text_fontsize, 
     27            'text.fontsize'   : text_fontsize, 
     28            'xtick.labelsize' : tick_labelsize, 
     29            'ytick.labelsize' : tick_labelsize, 
     30            'text.usetex'     : useTex, 
     31            'figure.figsize'  : [fig_width, fig_height]} 
     32 
     33        plotting.set_pylab_params(fig_width_pt=fig_width_pt, ratio=ratio, text_fontsize=text_fontsize, \ 
     34            tick_labelsize=tick_labelsize, useTex=useTex) 
     35        for k in testDict.keys(): 
     36            assert pylab.rcParams.has_key(k) 
     37            assert pylab.rcParams[k] == testDict[k] 
     38 
     39 
     40 
     41class GetDisplayTest(unittest.TestCase): 
     42 
     43    def runTest(self): 
     44         
     45        a = plotting.get_display(True) 
     46        assert a != None 
     47        a = plotting.get_display(False) 
     48        assert a == None 
     49        a = plotting.get_display(1234) 
     50        assert a == 1234 
     51 
     52 
     53 
     54class SetLabelsTest(unittest.TestCase): 
     55 
     56    def runTest(self): 
     57 
     58        f = plotting.get_display(True) 
     59        x = range(10) 
     60        p = pylab.plot(x) 
     61        plotting.set_labels(pylab, 'the x axis', 'the y axis') 
     62 
     63        # set up a SimpleMultiplot with arbitrary values 
     64        self.nrows = 1 
     65        self.ncolumns = 1 
     66        title = 'testMultiplot' 
     67        xlabel = 'testXlabel' 
     68        ylabel = 'testYlabel' 
     69        scaling = ('linear','log') 
     70        self.smt = plotting.SimpleMultiplot(nrows=self.nrows, ncolumns=self.ncolumns, title=title, xlabel=xlabel, ylabel=ylabel, scaling=scaling) 
     71        plotting.set_labels(self.smt.panel(0), 'the x axis', 'the y axis') 
     72 
     73 
     74 
     75 
     76class SetAxisLimitsTest(unittest.TestCase): 
     77 
     78    def runTest(self): 
     79 
     80        f = plotting.get_display(True) 
     81        x = range(10) 
     82        pylab.plot(x) 
     83        plotting.set_axis_limits(pylab, 0., 123., -123., 456.) 
     84 
     85        # set up a SimpleMultiplot with arbitrary values 
     86        self.nrows = 1 
     87        self.ncolumns = 1 
     88        title = 'testMultiplot' 
     89        xlabel = 'testXlabel' 
     90        ylabel = 'testYlabel' 
     91        scaling = ('linear','log') 
     92        self.smt = plotting.SimpleMultiplot(nrows=self.nrows, ncolumns=self.ncolumns, title=title, xlabel=xlabel, ylabel=ylabel, scaling=scaling) 
     93        plotting.set_axis_limits(self.smt.panel(0), 0., 123., -123., 456.) 
     94 
     95 
     96 
     97class SimpleMultiplotTest(unittest.TestCase): 
     98 
     99    def setUp(self): 
     100     
     101        # define arbitrary values 
     102        self.nrows = 4 
     103        self.ncolumns = 5 
     104        title = 'testMultiplot' 
     105        xlabel = 'testXlabel' 
     106        ylabel = 'testYlabel' 
     107        scaling = ('linear','log') 
     108        self.smt = plotting.SimpleMultiplot(nrows=self.nrows, ncolumns=self.ncolumns, title=title, xlabel=xlabel, ylabel=ylabel, scaling=scaling) 
     109 
     110 
     111 
     112class SimpleMultiplotSaveTest(SimpleMultiplotTest): 
     113 
     114    def runTest(self): 
     115     
     116        filename = "deleteme.png" 
     117        if os.path.exists(filename): os.remove(filename) 
     118        self.smt.save(filename) 
     119        assert os.path.exists(filename) 
     120        os.remove(filename) 
     121 
     122 
     123 
     124class SimpleMultiplotSetFrameTest(SimpleMultiplotTest): 
     125 
     126    def runTest(self): 
     127     
     128        numPanels = self.nrows * self.ncolumns 
     129        boollist = [True,False,False,True] 
     130        for i in range(numPanels): 
     131            ax_indexed = self.smt.panel(i) 
     132            ax_next = self.smt.next_panel() 
     133            assert ax_indexed == ax_next 
     134            self.smt.set_frame(ax_indexed,boollist,linewidth=4) 
     135 
     136 
     137 
     138if __name__ == "__main__": 
     139    unittest.main()