1 | """ |
---|
2 | Filename: general.py |
---|
3 | Author: Shannon Mason, shannon.mason@ecmwf.int |
---|
4 | Description: Common plotting functions. |
---|
5 | """ |
---|
6 | |
---|
7 | import pandas as pd |
---|
8 | import numpy as np |
---|
9 | |
---|
10 | #For loading and handling netCDF data |
---|
11 | import xarray as xr |
---|
12 | |
---|
def format_time(ax, format_string="%H:%M", label='Time (UTC)'):
    """
    Format the x-axis of ``ax`` for time coordinates.

    Tick labels are drawn horizontally, centred under their ticks, and
    rendered with a ``matplotlib.dates.DateFormatter``.

    Parameters
    ----------
    ax : matplotlib Axes to format (DateFormatter interprets the x tick
        values as matplotlib dates).
    format_string : strftime-style pattern passed to ``DateFormatter``.
    label : x-axis label text.
    """
    import matplotlib.dates as mdates

    # Set rotation/alignment directly rather than round-tripping the labels
    # through set_xticklabels: that installs a FixedFormatter (and triggers
    # matplotlib's "FixedFormatter should only be used together with
    # FixedLocator" warning) which the DateFormatter below replaces anyway.
    ax.tick_params(axis='x', labelrotation=0)
    for tick_label in ax.get_xticklabels():
        tick_label.set_horizontalalignment('center')

    ax.xaxis.set_major_formatter(mdates.DateFormatter(format_string))
    ax.set_xlabel(label)
---|
21 | |
---|
22 | |
---|
def format_height(ax, scale=1.0e3, label='Height [km]'):
    """
    Format the y-axis of ``ax`` for height coordinates.

    Every tick value is divided by ``scale`` and shown as a mathtext number;
    with the default ``scale=1.0e3`` a tick at 2500 is shown as 2.5.

    Parameters
    ----------
    ax : matplotlib Axes to format.
    scale : divisor applied to tick values before display.
    label : y-axis label text.
    """
    from matplotlib import ticker

    def _scaled_tick(value, pos):
        return '${0:g}$'.format(value / scale)

    height_formatter = ticker.FuncFormatter(_scaled_tick)
    ax.yaxis.set_major_formatter(height_formatter)
    ax.set_ylabel(label)
---|
31 | |
---|
32 | |
---|
def format_temperature(ax, scale='K'):
    """
    Format the y-axis of ``ax`` for temperature coordinates.

    Tick values are rendered as mathtext numbers, the y-axis is made to run
    from high values at the bottom to low values at the top, and the axis
    label reflects the temperature unit.

    Parameters
    ----------
    ax : matplotlib Axes to format.
    scale : {'K', 'C'}
        Unit of the plotted temperatures: kelvin or degrees Celsius.

    Raises
    ------
    ValueError
        If ``scale`` is neither 'K' nor 'C'. This is checked before ``ax``
        is modified, so an invalid call leaves the axes untouched.
    """
    # Validate first so a bad ``scale`` does not leave a half-formatted axes.
    if scale == 'K':
        ylabel = 'Temperature [K]'
    elif scale == 'C':
        ylabel = r'Temperature [$^{\circ}$C]'
    else:
        raise ValueError("Scale must be either K or C")

    import matplotlib.ticker as ticker

    ticks_y = ticker.FuncFormatter(lambda x, pos: '${0:g}$'.format(x))
    ax.yaxis.set_major_formatter(ticks_y)

    # Only invert when not already inverted, so calling this twice does not
    # flip the axis back to its original orientation.
    if not ax.yaxis_inverted():
        ax.invert_yaxis()

    ax.set_ylabel(ylabel)
---|
48 | |
---|
49 | |
---|
def format_pressure(ax, scale=100, label='Pressure [hPa]'):
    """
    Format the y-axis of ``ax`` for pressure coordinates.

    Every tick value is divided by ``scale`` and shown as a mathtext number;
    with the default ``scale=100`` a tick at 85000 is shown as 850.

    Parameters
    ----------
    ax : matplotlib Axes to format.
    scale : divisor applied to tick values before display.
    label : y-axis label text.
    """
    from matplotlib import ticker

    def _scaled_tick(value, pos):
        return '${0:g}$'.format(value / scale)

    pressure_formatter = ticker.FuncFormatter(_scaled_tick)
    ax.yaxis.set_major_formatter(pressure_formatter)
    ax.set_ylabel(label)
---|
58 | |
---|
59 | |
---|
def format_latitude_tick(x, pos=None):
    """
    Return a mathtext tick label for latitude ``x`` in degrees.

    Negative values are labelled with their magnitude and an ``S`` suffix;
    zero and positive values get an ``N`` suffix. ``pos`` is accepted (and
    ignored) so this can be used directly with ``ticker.FuncFormatter``.
    """
    # Raw strings: "\c" is not a valid escape sequence and warns on 3.12+.
    if x < 0:
        return r"${:g}^\circ$S".format(-1 * x)
    return r"${:g}^\circ$N".format(x)


def format_latitude(ax):
    """
    Format the x-axis of ``ax`` for latitude coordinates.

    Tick values are shown as degrees with an N/S hemisphere suffix
    (see ``format_latitude_tick``).
    """
    import matplotlib.ticker as ticker

    ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_latitude_tick))
---|
67 | |
---|
def fancy_format_latitude(x):
    """
    Return a mathtext latitude label rounded to whole degrees.

    Negative ``x`` gives e.g. ``$30^{\\circ}$S`` for -30; zero and positive
    values get an ``N`` suffix.
    """
    # Raw strings on both branches: "\c" is an invalid escape sequence.
    if x < 0:
        return r"${:.0f}^{{\circ}}$S".format(-1 * x)
    return r"${:.0f}^{{\circ}}$N".format(x)


def unfancy_format_latitude(x):
    """
    Return a plain-text latitude label rounded to whole degrees.

    Negative ``x`` gives e.g. ``30S`` for -30; zero and positive values get
    an ``N`` suffix.
    """
    if x < 0:
        return "{:.0f}S".format(-1 * x)
    return "{:.0f}N".format(x)
---|
70 | |
---|
def snap_to_axis(ax, ax_ref):
    """
    Give ``ax`` the same horizontal extent as ``ax_ref``.

    The left edge and width of ``ax`` are copied from ``ax_ref``; the bottom
    edge and height of ``ax`` are kept as they are.
    """
    ref_box = ax_ref.get_position()
    own_box = ax.get_position()
    new_bounds = [ref_box.x0, own_box.y0, ref_box.width, own_box.height]
    ax.set_position(new_bounds)
---|
78 | |
---|
def get_figure_center(ax):
    """
    Return the horizontal midpoint of ``ax`` in figure coordinates.

    Averages the x coordinates of the two corner points of the axes'
    position bounding box.
    """
    corners = ax.get_position().get_points()
    left, right = corners[0][0], corners[1][0]
    return (left + right) / 2
---|
82 | |
---|
def get_figure_top(fig, ax, include_hspace=True):
    """
    Return a vertical figure coordinate derived from the position of ``ax``.

    The value is ``ax.get_position().get_points()[0][1]``, optionally plus
    ``fig.subplotpars.hspace`` when ``include_hspace`` is true.

    NOTE(review): ``get_points()[0]`` is the first corner (x0, y0) of the
    bounding box, i.e. the *bottom* edge despite the function name, and
    ``subplotpars.hspace`` is documented by matplotlib as a fraction of the
    average axes height rather than a figure coordinate — confirm both are
    intended by the callers.
    """
    lower_y = ax.get_position().get_points()[0][1]
    return lower_y + fig.subplotpars.hspace if include_hspace else lower_y
---|
89 | |
---|
def place_suptitle(fig, axes, suptitle, y=0.95, va='top'):
    """
    Add ``suptitle`` to ``fig``, centred over the first axes in ``axes``.

    Parameters
    ----------
    fig : matplotlib Figure receiving the title.
    axes : indexable collection of Axes; only ``axes[0]`` sets the
        horizontal position.
    suptitle : title text.
    y : vertical position of the title in figure coordinates.
    va : vertical alignment of the title text.
    """
    x_mid = get_figure_center(axes[0])
    fig.suptitle(suptitle, x=x_mid, y=y, ha='center', va=va)
---|
93 | |
---|
def add_subfigure_labels(axes, xloc=0.0, yloc=1.05, zorder=0, label_list=None, flatten_order='F'):
    """
    Write a bold panel label such as ``a)`` on each axes in ``axes``.

    Parameters
    ----------
    axes : Axes collection — an ndarray as returned by ``plt.subplots``, a
        (nested) list of Axes, or a single Axes.
    xloc, yloc : label position in axes coordinates (``ax.transAxes``).
    zorder : drawing order of the label text.
    label_list : sequence of labels, one per axes. ``None`` (the default) or
        an empty sequence uses the lowercase alphabet 'a', 'b', 'c', ...
    flatten_order : numpy flatten order ('F' column-major, 'C' row-major)
        deciding which axes receives which label.

    Raises
    ------
    IndexError
        If there are more axes than labels.
    """
    import string

    # ``None`` default instead of a shared mutable ``[]``; len() rather than
    # ``== []`` so numpy arrays of labels are handled too.
    if label_list is None or len(label_list) == 0:
        labels = string.ascii_lowercase
    else:
        labels = label_list

    # dtype=object keeps Axes instances as-is and lets lists or a single
    # Axes be flattened the same way as an ndarray from plt.subplots.
    flat_axes = np.asarray(axes, dtype=object).flatten(order=flatten_order)
    for i, ax in enumerate(flat_axes):
        ax.text(xloc, yloc, "%s)" % (labels[i],), va='baseline',
                transform=ax.transAxes, fontweight='bold', zorder=zorder)
---|