| | 54 | |
|---|
| | 55 | def set_frame(ax,boollist,linewidth=2): |
|---|
| | 56 | assert len(boollist) == 4 |
|---|
| | 57 | bottom = Line2D([0, 1], [0, 0], transform=ax.transAxes, linewidth=linewidth, color='k') |
|---|
| | 58 | left = Line2D([0, 0], [0, 1], transform=ax.transAxes, linewidth=linewidth, color='k') |
|---|
| | 59 | top = Line2D([0, 1], [1, 1], transform=ax.transAxes, linewidth=linewidth, color='k') |
|---|
| | 60 | right = Line2D([1, 0], [1, 1], transform=ax.transAxes, linewidth=linewidth, color='k') |
|---|
| | 61 | # anti-aliased? |
|---|
| | 62 | if boollist != [True,True,True,True]: |
|---|
| | 63 | ax.set_frame_on(False) |
|---|
| | 64 | for side,draw in zip([left,bottom,right,top],boollist): |
|---|
| | 65 | if draw: |
|---|
| | 66 | ax.add_line(side) |
|---|
| | 67 | |
|---|
| | 68 | class SimpleMultiplot(object): |
|---|
| | 69 | """ |
|---|
| | 70 | A figure consisting of multiple panels, all with the same datatype and |
|---|
| | 71 | the same x-range. |
|---|
| | 72 | """ |
|---|
| | 73 | |
|---|
| | 74 | def __init__(self, nrows, ncolumns, title="", xlabel=None, ylabel=None, |
|---|
| | 75 | scaling=('linear','linear')): |
|---|
| | 76 | self.fig = Figure() |
|---|
| | 77 | self.canvas = FigureCanvas(self.fig) |
|---|
| | 78 | self.axes = [] |
|---|
| | 79 | self.all_panels = self.axes |
|---|
| | 80 | self.nrows = nrows |
|---|
| | 81 | self.ncolumns = ncolumns |
|---|
| | 82 | self.n = nrows*ncolumns |
|---|
| | 83 | self._curr_panel = 0 |
|---|
| | 84 | self.title = title |
|---|
| | 85 | topmargin = 0.06 |
|---|
| | 86 | rightmargin = 0.02 |
|---|
| | 87 | bottommargin = 0.1 |
|---|
| | 88 | leftmargin=0.1 |
|---|
| | 89 | panelsep = 0.05 |
|---|
| | 90 | panelheight = (1 - topmargin - bottommargin - (nrows-1)*panelsep)/nrows |
|---|
| | 91 | panelwidth = (1 - leftmargin - rightmargin - (ncolumns-1)*panelsep)/ncolumns |
|---|
| | 92 | assert panelheight > 0 |
|---|
| | 93 | |
|---|
| | 94 | bottomlist = [bottommargin + i*panelsep + i*panelheight for i in range(nrows)] |
|---|
| | 95 | leftlist = [leftmargin + j*panelsep + j*panelwidth for j in range(ncolumns)] |
|---|
| | 96 | bottomlist.reverse() |
|---|
| | 97 | for j in range(ncolumns): |
|---|
| | 98 | for i in range(nrows): |
|---|
| | 99 | ax = self.fig.add_axes([leftlist[j],bottomlist[i],panelwidth,panelheight]) |
|---|
| | 100 | set_frame(ax,[True,True,False,False]) |
|---|
| | 101 | ax.xaxis.tick_bottom() |
|---|
| | 102 | ax.yaxis.tick_left() |
|---|
| | 103 | self.axes.append(ax) |
|---|
| | 104 | if xlabel: |
|---|
| | 105 | self.axes[self.nrows-1].set_xlabel(xlabel) |
|---|
| | 106 | if ylabel: |
|---|
| | 107 | self.fig.text(0.5*leftmargin,0.5,ylabel, |
|---|
| | 108 | rotation='vertical', |
|---|
| | 109 | horizontalalignment='center', |
|---|
| | 110 | verticalalignment='center') |
|---|
| | 111 | if scaling == ("linear","linear"): |
|---|
| | 112 | self.plot_function = "plot" |
|---|
| | 113 | elif scaling == ("log", "log"): |
|---|
| | 114 | self.plot_function = "loglog" |
|---|
| | 115 | elif scaling == ("log", "linear"): |
|---|
| | 116 | self.plot_function = "semilogx" |
|---|
| | 117 | elif scaling == ("linear", "log"): |
|---|
| | 118 | self.plot_function = "semilogy" |
|---|
| | 119 | else: |
|---|
| | 120 | raise Exception("Invalid value for scaling parameter") |
|---|
| | 121 | |
|---|
| | 122 | def finalise(self): |
|---|
| | 123 | """Adjustments to be made after all panels have been plotted.""" |
|---|
| | 124 | # Turn off tick labels for all x-axes except the bottom one |
|---|
| | 125 | self.fig.text(0.5, 0.99, self.title, horizontalalignment='center', |
|---|
| | 126 | verticalalignment='top') |
|---|
| | 127 | for ax in self.axes[0:self.nrows-1]+self.axes[self.nrows:]: |
|---|
| | 128 | ax.xaxis.set_ticklabels([]) |
|---|
| | 129 | |
|---|
| | 130 | def save(self,filename): |
|---|
| | 131 | """Save/print the figure to file.""" |
|---|
| | 132 | self.finalise() |
|---|
| | 133 | self.canvas.print_figure(filename) |
|---|
| | 134 | |
|---|
| | 135 | def next_panel(self): |
|---|
| | 136 | ax = self.axes[self._curr_panel] |
|---|
| | 137 | self._curr_panel += 1 |
|---|
| | 138 | if self._curr_panel >= self.n: |
|---|
| | 139 | self._curr_panel = 0 |
|---|
| | 140 | ax.plot1 = getattr(ax, self.plot_function) |
|---|
| | 141 | return ax |
|---|
| | 142 | |
|---|
| | 143 | def panel(self,i): |
|---|
| | 144 | """Return panel i.""" |
|---|
| | 145 | ax = self.axes[i] |
|---|
| | 146 | ax.plot1 = getattr(ax, self.plot_function) |
|---|
| | 147 | return ax |
|---|
| | 148 | |
|---|