Show
Ignore:
Timestamp:
04/18/08 14:18:44 (9 months ago)
Author:
apdavison
Message:

Added SimpleMultiplot class to the plotting module.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/src/plotting.py

    r106 r158  
    22import numpy #, pylab 
    33import sys 
    4  
     4from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 
     5from matplotlib.figure import Figure 
     6from matplotlib.lines import Line2D 
    57""" 
    68Plotting helpers 
     
    5052 
    5153 
     54 
     55def set_frame(ax,boollist,linewidth=2): 
     56    assert len(boollist) == 4 
     57    bottom = Line2D([0, 1], [0, 0], transform=ax.transAxes, linewidth=linewidth, color='k') 
     58    left   = Line2D([0, 0], [0, 1], transform=ax.transAxes, linewidth=linewidth, color='k') 
     59    top    = Line2D([0, 1], [1, 1], transform=ax.transAxes, linewidth=linewidth, color='k') 
     60    right  = Line2D([1, 0], [1, 1], transform=ax.transAxes, linewidth=linewidth, color='k') 
     61    # anti-aliased? 
     62    if boollist != [True,True,True,True]: 
     63        ax.set_frame_on(False) 
     64        for side,draw in zip([left,bottom,right,top],boollist): 
     65            if draw: 
     66                ax.add_line(side) 
     67 
     68class SimpleMultiplot(object): 
     69    """ 
     70    A figure consisting of multiple panels, all with the same datatype and 
     71    the same x-range. 
     72    """ 
     73     
     74    def __init__(self, nrows, ncolumns, title="", xlabel=None, ylabel=None, 
     75                 scaling=('linear','linear')): 
     76        self.fig = Figure() 
     77        self.canvas = FigureCanvas(self.fig) 
     78        self.axes = [] 
     79        self.all_panels = self.axes 
     80        self.nrows = nrows 
     81        self.ncolumns = ncolumns 
     82        self.n = nrows*ncolumns 
     83        self._curr_panel = 0 
     84        self.title = title 
     85        topmargin = 0.06 
     86        rightmargin = 0.02 
     87        bottommargin = 0.1 
     88        leftmargin=0.1 
     89        panelsep = 0.05 
     90        panelheight = (1 - topmargin - bottommargin - (nrows-1)*panelsep)/nrows 
     91        panelwidth = (1 - leftmargin - rightmargin - (ncolumns-1)*panelsep)/ncolumns 
     92        assert panelheight > 0 
     93         
     94        bottomlist = [bottommargin + i*panelsep + i*panelheight for i in range(nrows)] 
     95        leftlist = [leftmargin + j*panelsep + j*panelwidth for j in range(ncolumns)] 
     96        bottomlist.reverse() 
     97        for j in range(ncolumns): 
     98            for i in range(nrows): 
     99                ax = self.fig.add_axes([leftlist[j],bottomlist[i],panelwidth,panelheight]) 
     100                set_frame(ax,[True,True,False,False]) 
     101                ax.xaxis.tick_bottom() 
     102                ax.yaxis.tick_left() 
     103                self.axes.append(ax) 
     104        if xlabel: 
     105            self.axes[self.nrows-1].set_xlabel(xlabel) 
     106        if ylabel: 
     107            self.fig.text(0.5*leftmargin,0.5,ylabel, 
     108                          rotation='vertical', 
     109                          horizontalalignment='center', 
     110                          verticalalignment='center') 
     111        if scaling == ("linear","linear"): 
     112            self.plot_function = "plot" 
     113        elif scaling == ("log", "log"): 
     114            self.plot_function = "loglog" 
     115        elif scaling == ("log", "linear"): 
     116            self.plot_function = "semilogx" 
     117        elif scaling == ("linear", "log"): 
     118            self.plot_function = "semilogy" 
     119        else: 
     120            raise Exception("Invalid value for scaling parameter") 
     121     
     122    def finalise(self): 
     123        """Adjustments to be made after all panels have been plotted.""" 
     124        # Turn off tick labels for all x-axes except the bottom one 
     125        self.fig.text(0.5, 0.99, self.title, horizontalalignment='center', 
     126                      verticalalignment='top') 
     127        for ax in self.axes[0:self.nrows-1]+self.axes[self.nrows:]: 
     128            ax.xaxis.set_ticklabels([]) 
     129 
     130    def save(self,filename): 
     131        """Save/print the figure to file.""" 
     132        self.finalise() 
     133        self.canvas.print_figure(filename) 
     134     
     135    def next_panel(self): 
     136        ax = self.axes[self._curr_panel] 
     137        self._curr_panel += 1 
     138        if self._curr_panel >= self.n: 
     139            self._curr_panel = 0 
     140        ax.plot1 = getattr(ax, self.plot_function) 
     141        return ax 
     142         
     143    def panel(self,i): 
     144        """Return panel i.""" 
     145        ax = self.axes[i] 
     146        ax.plot1 = getattr(ax, self.plot_function) 
     147        return ax 
     148