# Source code for ldcpy.plot

import calendar
import copy
import warnings

import cartopy
import cf_xarray as cf
import cmocean
import matplotlib as mpl
import numpy as np
import xarray as xr
import xrft
from cartopy import crs as ccrs
from cartopy.util import add_cyclic_point
from matplotlib import pyplot as plt
from pylab import flipud

from ldcpy import calcs as lm, util as lu
from ldcpy.convert import CalendarDateTime


def tex_escape(text):
    """
    Prepare a message for use in plot labels and titles.

    LaTeX escaping of special characters (&, %, $, #, _, {, }, ^, \\, <, >) is currently
    disabled, so this is a pass-through hook: every label in this module is routed through it
    so escaping can be re-enabled in one place.

    :param text: a plain text message
    :return: ``text``, unchanged
    """
    escaped = text
    return escaped
class calcsPlot(object):
    """
    This class contains code to plot calcs in an xarray Dataset that has either 'lat' and 'lon'
    dimensions, or a 'time' dimension.

    ``standardized_err`` and ``contour_levs`` are accepted by the constructor; ``contour_levs``
    is stored but not read by any method in this class, and ``standardized_err`` is not stored.
    """

    def __init__(
        self,
        ds,
        varname,
        calc,
        sets,
        group_by=None,
        scale='linear',
        calc_type='raw',
        plot_type='spatial',
        transform='none',
        subset=None,
        approx_lat=None,
        approx_lon=None,
        lev=0,
        color='coolwarm',
        standardized_err=False,
        quantile=None,
        calc_ssim=False,
        contour_levs=24,
        short_title=False,
        axes_symmetric=False,
        legend_loc='upper right',
        vert_plot=False,
        tex_format=False,
        legend_offset=None,
        weighted=True,
    ):
        self._ds = ds

        # calc settings used in plot titles
        self._varname = varname
        self._sets = sets
        self._title_lat = None
        self._title_lon = None
        # get_title() reads the public title_lat/title_lon attributes (plot() assigns them);
        # initialise them so get_title() also works when a caller never sets them.
        self.title_lat = None
        self.title_lon = None

        # Plot settings
        self._calc = calc
        self._group_by = group_by
        self._scale = scale
        self._calc_type = calc_type
        self._plot_type = plot_type
        self._subset = subset
        self._true_lat = approx_lat
        self._true_lon = approx_lon
        self._transform = transform
        self._lev = lev
        self._color = color
        self._short_title = short_title
        self._quantile = quantile
        self._calc_ssim = calc_ssim
        self._contour_levs = contour_levs
        self._axes_symmetric = axes_symmetric
        self._legend_loc = legend_loc
        self.vert_plot = vert_plot
        self._tex_format = tex_format
        self._legend_offset = legend_offset
        self._weighted = weighted

    def verify_plot_parameters(self):
        """
        Reject option combinations that the plotting methods do not support.

        :raises ValueError: describing the first invalid combination found
        """
        if len(self._sets) < 2 and self._calc_type in [
            'diff',
            'ratio',
            'calc_of_diff',
        ]:
            raise ValueError(f'Must specify set2 for {self._calc_type} calc type')
        if self._plot_type in ['spatial'] and self._group_by is not None:
            raise ValueError(f'Cannot group by {self._group_by} in a non-time-series plot')
        if self._plot_type not in ['spatial'] and self._color != 'coolwarm':
            raise ValueError('Cannot change color scheme in a non-spatial plot')
        if self._plot_type in ['spatial'] and (
            self._true_lat is not None or self._true_lon is not None
        ):
            raise ValueError('Cannot currently subset by latitude or longitude in a spatial plot')
        if self._lev != 0:  # and 'lev' not in self._ds.dims:
            # a level can only be selected if cf_xarray can find a vertical coordinate
            try:
                vert = self._ds.cf['vertical']
            except KeyError:
                vert = None
            if vert is None:
                raise ValueError('Cannot subset by lev (vertical dimension) in this dataset')
        if self._quantile is not None and self._calc != 'quantile':
            raise ValueError('Cannot change quantile value if calc is not quantile')
        if self._quantile is None and self._calc == 'quantile':
            raise ValueError('Must specify quantile value as argument')
        if self._calc in ['lag1', 'corr_lag1', 'mae_day_max'] and self._plot_type not in [
            'spatial',
        ]:
            raise ValueError(f'Cannot plot {self._calc} in a non-spatial plot')

    @staticmethod
    def _lat_lon_dims(da):
        """
        Return the ``(lat_dim, lon_dim)`` dimension names of ``da``.

        lat/lon dim names differ between ocean and atmosphere output: with 1-D coordinates the
        latitude and longitude coordinates each own one dimension; with 2-D coordinates both
        dimensions belong to the latitude coordinate.

        :raises ValueError: if the latitude coordinate is neither 1-D nor 2-D
        """
        lat_dims = da.cf[da.cf.coordinates['latitude'][0]].dims
        if len(lat_dims) == 1:
            return lat_dims[0], da.cf['longitude'].dims[0]
        if len(lat_dims) == 2:
            return lat_dims[0], lat_dims[1]
        raise ValueError(f'Unsupported latitude coordinate with dimensions {lat_dims}')

    def get_calcs(self, da):
        """
        Compute the configured calc on ``da``.

        Spatial plots reduce over 'time'; time-series, periodogram and histogram plots reduce
        over the lat/lon dimensions.

        :param da: DataArray to compute the calc on
        :return: result of ``Datasetcalcs.get_calc`` for this plot's calc/quantile/group_by
        :raises ValueError: for an unsupported plot type
        """
        if self._plot_type in ['spatial']:
            calcs_da = lm.Datasetcalcs(da, ['time'], weighted=self._weighted)
        elif self._plot_type in ['time_series', 'periodogram', 'histogram']:
            lat_dim, lon_dim = self._lat_lon_dims(da)
            calcs_da = lm.Datasetcalcs(da, [lat_dim, lon_dim], weighted=self._weighted)
        else:
            raise ValueError(f'plot type {self._plot_type} not supported')

        if self._calc_ssim and self._plot_type != 'spatial':
            warnings.warn(
                'SSIM is only calculated for spatial plots, ignoring calc_ssim option', UserWarning
            )

        return calcs_da.get_calc(self._calc, self._quantile, self._group_by)

    def get_plot_data(self, raw_data_1, raw_data_2=None):
        """
        Combine raw calc values according to the calc type, then group and transform them.

        - 'diff':  ``raw_data_1 - raw_data_2``
        - 'ratio': ``raw_data_2 / raw_data_1`` (units set to '')
        - 'raw' / 'calc_of_diff': ``raw_data_1`` unchanged

        :raises ValueError: for an unsupported calc type or transform
        """
        if self._calc_type == 'diff':
            plot_data = raw_data_1 - raw_data_2
            # copy so later edits to plot_data.attrs don't alter raw_data_1
            plot_data.attrs = dict(raw_data_1.attrs)
        elif self._calc_type == 'ratio':
            plot_data = raw_data_2 / raw_data_1
            plot_data.attrs = dict(raw_data_1.attrs)
            # a ratio of two values in the same units is dimensionless
            plot_data.attrs['units'] = ''
        elif self._calc_type == 'raw' or self._calc_type == 'calc_of_diff':
            plot_data = raw_data_1
        else:
            raise ValueError(f'calc_type {self._calc_type} not supported')

        if self._group_by is not None and self._calc not in [
            'standardized_mean',
            'odds_positive',
        ]:
            plot_attrs = plot_data.attrs
            plot_data = plot_data.groupby(self._group_by).mean(dim='time')
            plot_data.attrs = plot_attrs

        if self._transform == 'none':
            pass
        elif self._transform == 'log':
            plot_attrs = plot_data.attrs
            plot_data = np.log10(plot_data)
            plot_data.attrs = plot_attrs
        else:
            raise ValueError(f'calc transformation {self._transform} not supported')

        return plot_data

    def get_title(self, calc_name, c_name=None):
        """
        Build a plot title.

        :param calc_name: calc label (e.g. from get_calc_label)
        :param c_name: dataset label; defaults to the first set name
        :return: the title string ('' for short-titled time-series plots)
        """
        if c_name is not None:
            das = f'{c_name}'
        else:
            das = f'{self._sets[0]}'

        if self._short_title is True:
            if self._plot_type == 'time_series':
                return ''
            else:
                return das

        if self._quantile is not None and calc_name == 'quantile':
            calc_full_name = f'{calc_name} {self._quantile}'
        else:
            calc_full_name = calc_name

        if self._transform == 'log':
            title = f'{self._varname}: log10 {calc_full_name}'
        else:
            title = f'{self._varname}: {calc_full_name}'

        if self._plot_type == 'spatial':
            title = f'{das}: {title}'

        if self._calc_type != 'raw':
            title = f'{title} {self._calc_type}'

        if self._group_by is not None:
            title = f'{title} by {self._group_by}'

        if self.title_lat is not None:
            if self.title_lon is not None:
                title = f'{title} at lat={self.title_lat:.2f}, lon={self.title_lon:.2f}'
            else:
                title = f'{title} at lat={self.title_lat:.2f}'
        elif self.title_lon is not None:
            title = f'{title} at lon={self.title_lon:.2f}'

        if self._subset is not None:
            title = f'{title} subset:{self._subset}'

        if self._plot_type == 'histogram':
            title = f'time-series histogram:{title}'
        elif self._plot_type == 'periodogram':
            title = f'periodogram:{title}'

        return title

    def _label_offset(
        self,
        ax,
    ):
        """
        Move the y-axis offset text (e.g. '1e-5') into the y-axis label.

        The offset text is hidden and the label is rewritten as 'label (offset)'; a
        'ylim_changed' callback keeps the label up to date.
        """
        fmt = ax.yaxis.get_major_formatter()
        ax.yaxis.offsetText.set_visible(False)
        set_label = ax.set_ylabel
        label = ax.get_ylabel()

        def update_label(event_axes):
            offset = fmt.get_offset()
            if offset == '':
                set_label('{}'.format(label))
            else:
                set_label('{} ({})'.format(label, offset))
            return

        ax.callbacks.connect('ylim_changed', update_label)
        # draw once so the formatter has computed its offset before the first update
        ax.figure.canvas.draw()
        update_label(None)
        return

    @staticmethod
    def _nan_inf_proxies():
        """Legend handles for the NaN (green), -Inf (black) and Inf (white) map colours."""
        return [
            plt.Rectangle((0, 0), 1, 1, fc='#39ff14'),
            plt.Rectangle((0, 1), 2, 2, fc='#000000'),
            plt.Rectangle((0, 1), 2, 2, fc='#ffffff', edgecolor='black'),
        ]

    @staticmethod
    def _report_ssim(n_sets):
        """
        Print the SSIM of each saved panel image against the first, then delete the images.

        Reads the 'tmp_ssim{i}.png' files written by spatial_plot; they are removed even if
        reading or comparing fails.
        """
        import os

        try:
            import matplotlib.image as mpimg
            from skimage.metrics import structural_similarity as ssim

            img1 = mpimg.imread('tmp_ssim1.png')
            for i in range(1, n_sets):
                img2 = mpimg.imread(f'tmp_ssim{i+1}.png')
                ssim_val = ssim(
                    img1,
                    img2,
                    multichannel=True,
                    gaussian_weights=True,
                    use_sample_covariance=False,
                )
                print(f' SSIM 1 & {i+1} = % 5.5f\n' % (ssim_val))
        finally:
            for i in range(n_sets + 1):
                path = f'tmp_ssim{i}.png'
                if os.path.exists(path):
                    os.remove(path)

    def spatial_plot(self, da_sets, titles):
        """
        Draw one Robinson-projection map per entry of the 'sets' dimension of ``da_sets``.

        NaN cells are shown in green, -Inf in black and +Inf in white, with a legend for
        them when any occur. All panels share one colour range and colourbar.

        :param da_sets: DataArray with a 'sets' dimension plus lat/lon
        :param titles: one title per set
        """
        if self.vert_plot:
            nrows = int(da_sets.sets.size)
        else:
            nrows = int((da_sets.sets.size + 1) / 2)
        if len(da_sets) == 1:
            ncols = 1
        else:
            ncols = 2

        if self._calc == 'zscore':
            ncols = 1
            nrows = len(da_sets)

        if self.vert_plot:
            fig = plt.figure(dpi=300, figsize=(4.5, 2.5 * nrows))
        else:
            fig = plt.figure(dpi=300, figsize=(9, 2.5 * nrows))
        plt.rcParams.update({'font.size': 10})

        try:
            base_cmap = mpl.colormaps[f'{self._color}']
        except AttributeError:  # matplotlib < 3.5 has no colormap registry
            base_cmap = mpl.cm.get_cmap(f'{self._color}')
        mymap = copy.copy(base_cmap)
        # out-of-range colours mark -Inf / +Inf (see the float32 cast below); NaN is transparent
        # so the green axes facecolour shows through
        mymap.set_under(color='black')
        mymap.set_over(color='white')
        mymap.set_bad(alpha=0)

        axs = {}
        psets = {}
        nan_inf_flag = 0
        all_nan_flag = 0
        cmax = []
        cmin = []

        # lat/lon could be 1 or 2d and have different names
        lon_coord_name = da_sets[0].cf.coordinates['longitude'][0]
        lat_coord_name = da_sets[0].cf.coordinates['latitude'][0]

        # is the lat/lon 1d or 2d (to do: set error if > 2)
        coord_ndim = da_sets[0].cf[lon_coord_name].ndim

        central = 0.0  # might make this a parameter later
        if coord_ndim == 2:  # probably pop
            central = 300.0

        for i in range(da_sets.sets.size):
            if self.vert_plot:
                axs[i] = plt.subplot(
                    nrows, 1, i + 1, projection=ccrs.Robinson(central_longitude=central)
                )
            else:
                axs[i] = plt.subplot(
                    nrows, ncols, i + 1, projection=ccrs.Robinson(central_longitude=central)
                )

            axs[i].set_facecolor('#39ff14')

            # make data periodic
            if coord_ndim == 2:
                ylon = da_sets[i][lon_coord_name]
                lon_sets = np.hstack((ylon, ylon[:, 0:1]))
                xlat = da_sets[i][lat_coord_name]
                lat_sets = np.hstack((xlat, xlat[:, 0:1]))
            cy_datas = add_cyclic_point(da_sets[i])

            if np.isnan(cy_datas).any() or np.isinf(cy_datas).any():
                nan_inf_flag = 1
            if np.isnan(cy_datas).all():
                all_nan_flag = 1

            cyxr = xr.DataArray(data=cy_datas)
            if not np.isinf(cyxr).all():
                cmin.append(np.min(cyxr.where(cyxr != -np.inf).min()))
                cmax.append(np.max(cyxr.where(cyxr != np.inf).max()))

            # nan_to_num(nan=nan) keeps NaN but maps +/-Inf to the largest finite float32,
            # which falls outside the colour limits and so uses the over/under colours.
            # Casting to float32 also prevents lots of tiny black dots from showing up in some
            # plots with lots of zeroes (e.g. probability of negative PRECT).
            if coord_ndim == 2:
                no_inf_data_set = np.nan_to_num(cyxr.astype(np.float32), nan=np.nan)
            else:
                # imshow draws the image over the full -180..180 extent, so roll the columns to
                # put lon=180 first. Assumes lon starts at 0 with an even number of points;
                # this reproduces the former hard-coded shift of 145 for 288-point grids.
                shift = cyxr.sizes['dim_1'] // 2 + 1
                ncyxr = cyxr.roll(dim_1=shift)
                no_inf_data_set = np.nan_to_num(ncyxr.astype(np.float32), nan=np.nan)

            if coord_ndim == 2:
                psets[i] = axs[i].pcolormesh(
                    lon_sets,
                    lat_sets,
                    no_inf_data_set,
                    transform=ccrs.PlateCarree(),
                    cmap=mymap,
                )
            else:
                psets[i] = axs[i].imshow(
                    img=flipud(no_inf_data_set), transform=ccrs.PlateCarree(), cmap=mymap
                )

            axs[i].set_global()

            # save a borderless image of this panel for the SSIM comparison below
            if self._calc_ssim:
                axs[i].axis('off')
                plt.margins(0, 0)
                extent1 = axs[i].get_window_extent().transformed(fig.dpi_scale_trans.inverted())
                plt.savefig(f'tmp_ssim{i+1}', bbox_inches=extent1, transparent=True, pad_inches=0)
                axs[i].axis('on')

            # may need to be modified for other components
            if coord_ndim == 1:
                axs[i].coastlines()
            else:
                axs[i].add_feature(
                    cartopy.feature.NaturalEarthFeature(
                        'physical',
                        'land',
                        '110m',
                        linewidth=0.5,
                        edgecolor='black',
                        facecolor='darkgray',
                    )
                )

            axs[i].set_title(tex_escape(titles[i]))

            del cy_datas

        # shared colour range across all panels (fallback when every panel is all-Inf)
        color_min = np.min(cmin) if len(cmin) > 0 else -0.1
        color_max = np.max(cmax) if len(cmax) > 0 else 0.1

        if self._axes_symmetric:
            color_max_abs = max(abs(color_min), abs(color_max))
            color_min = -1 * color_max_abs
            color_max = color_max_abs

        for pset in psets.values():
            pset.set_clim(color_min, color_max)

        # add colorbar
        if self.vert_plot is False:
            fig.subplots_adjust(left=0.1, right=0.9, bottom=0.2, top=0.95)

        cbs = []
        if not all_nan_flag:
            cax = fig.add_axes([0.1, 0, 0.8, 0.05])
            for i in range(len(psets)):
                cbs.append(fig.colorbar(psets[i], cax=cax, orientation='horizontal', shrink=0.95))
                cbs[i].ax.set_title(f'{da_sets[i].units}')
                cbs[i].ax.set_aspect(0.03)
                if self.vert_plot or len(psets) > 2:
                    cbs[i].ax.set_anchor((0, 1.35 + 0.15 * (nrows - 1)))
                else:
                    cbs[i].ax.set_anchor((0.5, 1.35 + 0.15 * (nrows - 1)))
                cbs[i].ax.tick_params(labelsize=8, rotation=30)

            if nan_inf_flag:
                proxy = self._nan_inf_proxies()
                if self.vert_plot:
                    plt.rcParams.update({'font.size': 8})
                    plt.legend(
                        proxy,
                        ['NaN', '-Inf', 'Inf'],
                        loc='lower center',
                        bbox_to_anchor=(0.51, -6),
                        ncol=len(proxy),
                    )
                else:
                    plt.rcParams.update({'font.size': 10})
                    anchor = (0.672, 4) if len(psets) > 2 else (0.78, -2)
                    plt.legend(
                        proxy,
                        ['NaN', '-Inf', 'Inf'],
                        bbox_to_anchor=anchor,
                        ncol=len(proxy),
                    )
        else:
            # every value is NaN: no colourbar, just the NaN/Inf legend
            fig.add_axes([0.1, 0, 0.8, 0.05])
            proxy = self._nan_inf_proxies()
            plt.legend(proxy, ['NaN', '-Inf', 'Inf'], bbox_to_anchor=(0.87, 2), ncol=len(proxy))
            plt.axis('off')

        if self._calc_ssim:
            self._report_ssim(len(da_sets))

    def _draw_legend(self):
        """Place the legend for histogram and periodogram plots and set the font size."""
        if self.vert_plot:
            plt.legend(loc=self._legend_loc, borderaxespad=1.0)
            plt.rcParams.update({'font.size': 16})
        else:
            plt.rcParams.update({'font.size': 10})
            if self._legend_loc is None:
                plt.legend(bbox_to_anchor=(1.05, 1), loc=self._legend_loc, borderaxespad=0.0)
            else:
                plt.legend(
                    bbox_to_anchor=self._legend_offset, loc=self._legend_loc, borderaxespad=0.0
                )

    def hist_plot(self, plot_data, title):
        """
        Overlay histograms of every set in ``plot_data``.

        :param plot_data: DataArray with a 'sets' dimension
        :param title: list of titles; only the first is used
        """
        _, axs = mpl.pyplot.subplots(1, 1, sharey=True, tight_layout=True)

        sets = [plot_data.sel(sets=s) for s in plot_data.sets]
        axs.hist(sets, label=[f'{name}' for name in plot_data.sets.data])

        if plot_data.units != '':
            mpl.pyplot.xlabel(tex_escape(f'{self._calc} ({plot_data.units})'))
        else:
            mpl.pyplot.xlabel(tex_escape(f'{self._calc}'))
        mpl.pyplot.title(tex_escape(title[0]))

        self._draw_legend()

    def periodogram_plot(self, plot_data, title):
        """
        Plot log10 of the periodogram of each set's (mean-removed) series.

        :param plot_data: DataArray with a 'sets' dimension and one series per set
        :param title: list of titles; only the first is used
        """
        plt.figure()
        for j in range(plot_data.sets.size):
            anomaly = plot_data[j].data - plot_data[j].data.mean()
            dat = xrft.dft(xr.DataArray(anomaly).chunk(anomaly.size))
            power = (np.multiply(dat, np.conj(dat)) / dat.size).real
            power = np.log10(power[2 : int(dat.size / 2) + 1])
            freqs = np.array(range(1, int(dat.size / 2))) / dat.size
            mpl.pyplot.plot(freqs, power, label=f'{plot_data.sets.data[j]}')

        self._draw_legend()

        mpl.pyplot.title(tex_escape(title[0]))
        mpl.pyplot.ylabel('Spectrum')
        mpl.pyplot.xlabel('Frequency')

    def time_series_plot(
        self,
        da_sets,
        titles,
    ):
        """
        Plot one line per set against time, or against the group_by label when grouped.

        :param da_sets: DataArray with a 'sets' dimension plus 'time' (or the group label)
        :param titles: list of titles; only the first is used
        """
        group_string = 'time.year'
        xlabel = 'date'
        tick_interval = int(da_sets.size / da_sets.sets.size / 5) + 1
        if da_sets.size / da_sets.sets.size == 1:
            tick_interval = 1
        if self._group_by == 'time.dayofyear':
            group_string = 'dayofyear'
            xlabel = 'Day of Year'
        elif self._group_by == 'time.month':
            group_string = 'month'
            xlabel = 'Month'
            tick_interval = 1
        elif self._group_by == 'time.year':
            group_string = 'year'
            xlabel = 'Year'
        elif self._group_by == '':
            group_string = 'day'
            xlabel = 'Day'

        units_part = f' ({da_sets.units})' if self._calc_type != 'ratio' and da_sets.units != '' else ''
        if self._calc_type == 'diff':
            ylabel = f'{self._calc}{units_part} diff'
        elif self._calc_type == 'ratio':
            ylabel = f'{self._calc} ratio'
        elif self._calc_type == 'calc_of_diff':
            ylabel = f'{self._calc}{units_part} of diff'
        else:
            ylabel = f'{self._calc}{units_part}'

        if self._transform == 'log':
            plot_ylabel = f'log10 {ylabel}'
        else:
            plot_ylabel = ylabel

        mpl.style.use('default')
        plt.figure()

        if self.vert_plot:
            plt.rcParams.update({'font.size': 16})
        else:
            plt.rcParams.update({'font.size': 10})

        plt.rcParams.update(
            {
                'text.usetex': self._tex_format,
            }
        )

        for i in range(da_sets.sets.size):
            set_label = f'{da_sets.sets.data[i]}'
            if self._group_by is not None:
                plt.plot(
                    da_sets[i][group_string].data,
                    da_sets[i],
                    f'C{i}',
                    label=set_label,
                )
                ax = plt.gca()
            else:
                dtindex = da_sets[i].indexes['time']
                c_d_time = [CalendarDateTime(item, '365_day') for item in dtindex]
                mpl.pyplot.plot(c_d_time, da_sets[i], f'C{i}', label=set_label)
                ax = plt.gca()
                for label in ax.get_xticklabels():
                    label.set_rotation(30)
                    label.set_horizontalalignment('right')

        if self.vert_plot:
            if self._legend_offset is None:
                plt.legend(loc=self._legend_loc, borderaxespad=1.0)
            else:
                plt.legend(
                    loc=self._legend_loc, borderaxespad=1.0, bbox_to_anchor=self._legend_offset
                )
        else:
            plt.rcParams.update({'font.size': 10})
            if self._legend_offset is None:
                plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0)
            else:
                plt.legend(bbox_to_anchor=self._legend_offset, loc='upper left', borderaxespad=0.0)

        mpl.pyplot.ylabel(tex_escape(plot_ylabel))
        mpl.pyplot.yscale(self._scale)
        self._label_offset(ax)
        mpl.pyplot.xlabel(tex_escape(xlabel))

        if self._group_by is not None:
            mpl.pyplot.xticks(
                np.arange(min(da_sets[group_string]), max(da_sets[group_string]) + 1, tick_interval)
            )
            if self._group_by == 'time.month':
                int_labels = plt.xticks()[0]
                month_labels = [
                    calendar.month_name[int(m)]
                    for m in int_labels
                    if calendar.month_name[int(m)] != ''
                ]
                unique_month_labels = list(dict.fromkeys(month_labels))
                plt.gca().set_xticklabels(unique_month_labels)
                for label in ax.get_xticklabels():
                    label.set_rotation(30)
                    label.set_horizontalalignment('right')

        mpl.pyplot.title(tex_escape(titles[0]))

    def get_calc_label(self, calc, data, weights=None):
        """
        Build the calc part of a plot title, including a summary value where one exists.

        :param calc: name of the calc being plotted
        :param data: DataArray the calc is computed on
        :param weights: accepted for backward compatibility with callers that pass the
            dataset's 'gw' values; it is not used (weighting follows the ``weighted`` option)
        :return: the label, or '' unless short titles are disabled (short_title is False)
        """
        if self._short_title is not False:
            return ''

        if calc == 'zscore':
            calcs = lm.Datasetcalcs(data, ['time'], weighted=self._weighted)
            zscore_cutoff = calcs.get_single_calc('zscore_cutoff')
            percent_sig = calcs.get_single_calc('zscore_percent_significant')
            return f'{calc}: cutoff {zscore_cutoff[0]:.2f}, % sig: {percent_sig:.2f}'

        if calc == 'pooled_var_ratio':
            pooled_sd = np.sqrt(
                lm.Datasetcalcs(data, ['time'], weighted=self._weighted).get_single_calc(
                    'pooled_variance'
                )
            )
            d = float(pooled_sd)
            return f'{calc}: pooled SD = {d:.2f}'

        if calc == 'ann_harmonic_ratio':
            p = lm.Datasetcalcs(data, ['time'], weighted=self._weighted).get_single_calc(
                'annual_harmonic_relative_ratio_pct_sig'
            )
            return f'{calc}: % sig = {p:.2f}'

        if self._plot_type == 'spatial':
            # mean over the map (area-weighted via cf_xarray when weighting is enabled)
            calc_da = lm.Datasetcalcs(data, ['time'], weighted=self._weighted).get_calc(calc)
            if self._weighted:
                calc_da = calc_da.cf.weighted('area')
            a1_data = float(calc_da.mean())
            return f'{calc} = {a1_data:.2f}'

        if self._plot_type == 'time_series':
            lat_dim, lon_dim = self._lat_lon_dims(data)
            a1_data = float(
                lm.Datasetcalcs(data, [lat_dim, lon_dim], weighted=self._weighted)
                .get_calc(calc)
                .mean()
            )
            return f'{calc} = {a1_data:.2f}'

        return calc
def plot(
    ds,
    varname,
    calc,
    sets,
    group_by=None,
    scale='linear',
    calc_type='raw',
    plot_type='spatial',
    transform='none',
    subset=None,
    lat=None,
    lon=None,
    lev=0,
    color='coolwarm',
    quantile=None,
    start=None,
    end=None,
    calc_ssim=False,
    short_title=False,
    axes_symmetric=False,
    legend_loc='upper right',
    vert_plot=False,
    tex_format=False,
    legend_offset=None,
    weighted=True,
):
    """
    Plots the data given an xarray dataset

    Parameters
    ==========
    ds : xarray.Dataset
        The input dataset
    varname : str
        The name of the variable to be plotted
    calc : str
        The name of the calc to be plotted (must match a property name in the Datasetcalcs
        class in ldcpy.calcs; for more information about the available calcs see
        ldcpy.Datasetcalcs). Acceptable values include:

            - ns_con_var
            - ew_con_var
            - mean
            - std
            - variance
            - prob_positive
            - prob_negative
            - odds_positive
            - zscore
            - mean_abs
            - mean_squared
            - rms
            - sum
            - sum_squared
            - corr_lag1
            - quantile
            - lag1
            - standardized_mean
            - ann_harmonic_ratio
            - pooled_var_ratio

    sets : list <str>
        The labels of the dataset to gather calcs from
    group_by : str
        how to group the data in time series plots (default None, no grouping).
        Valid groupings:

            - time.dayofyear
            - time.month
            - time.year

    scale : str, optional
        time-series y-axis plot transformation. (default "linear")
        Valid options:

            - linear
            - log

    calc_type : str, optional
        The type of operation to be performed on the calcs. (default 'raw')
        Valid options:

            - raw: the unaltered calc values
            - diff: the difference between the calc values in the first set and every other set
            - ratio: the ratio of the calc values in (2nd, 3rd, 4th... sets/1st set)
            - calc_of_diff: the calc value computed on the difference between the first set
              and every other set

    plot_type : str , optional
        The type of plot to be created. (default 'spatial')
        Valid options:

            - spatial: a plot of the world with values at each lat and lon point (takes the
              mean across the time dimension)
            - time_series: A time-series plot of the data (computed by taking the mean across
              the lat and lon dimensions)
            - histogram: A histogram of the time-series data
            - periodogram: A periodogram of the time-series data

    transform : str, optional
        data transformation. (default 'none')
        Valid options:

            - none
            - log

    subset : str, optional
        subset of the data to gather calcs on (default None).
        Valid options:

            - first5: the first 5 days of data
            - DJF: data from the months December, January, February
            - MAM: data from the months March, April, May
            - JJA: data from the months June, July, August
            - SON: data from the months September, October, November

    lat : float, optional
        The latitude of the data to gather calcs on (default None).
    lon : float , optional
        The longitude of the data to gather calcs on (default None).
    lev : float, optional
        The level of the data to gather calcs on (used if plotting from a 3d data set),
        (default 0).
    color : str, optional
        The color scheme (matplotlib colormap name) for spatial plots, (default 'coolwarm').
        See the matplotlib colormap reference for more options.
    quantile : float, optional
        A value between 0 and 1, required if calc="quantile", corresponding to the desired
        quantile to gather, (default None).
    start : int, optional
        A value between 0 and the number of time slices indicating the start time of a
        subset, (default None).
    end : int, optional
        A value between 0 and the number of time slices indicating the end time of a subset,
        (default None)
    calc_ssim : bool, optional
        Whether or not to calculate the ssim (structural similarity index) between two plots
        (only applies to plot_type = 'spatial'), (default False)
    short_title: bool, optional
        If True, use a shortened title in the plot output (default False).
    axes_symmetric: bool, optional
        Whether or not to make the colorbar axes symmetric about zero (used in a spatial plot)
        (default False)
    legend_loc: str, optional
        The location to put the legend in a time-series plot in single-column format
        (plot_type = "time_series", vert_plot=True) (default "upper right")
    vert_plot: bool, optional
        If true, forces plots into a single column format and enlarges text. (default False)
    tex_format: bool, optional
        Whether to interpret all plot output strings as latex formatting (default False)
    legend_offset: 2-tuple, optional
        The x- and y- offset of the legend. Moves the corner of the legend specified by
        legend_loc to the specified location (where (0,0) is the bottom left corner of the
        plot and (1,1) is the top right corner). Only affects time-series, histogram, and
        periodogram plots.
    weighted: bool, optional
        Whether calcs (and the summary values in titles) are area-weighted (default True).

    Returns
    =======
    out : None
    """
    # Everything after ``sets`` is passed by keyword: calcsPlot's 15th positional parameter
    # is ``standardized_err``, so passing ``quantile`` positionally would drop it and make
    # verify_plot_parameters() reject every calc='quantile' request.
    mp = calcsPlot(
        ds,
        varname,
        calc,
        sets,
        group_by=group_by,
        scale=scale,
        calc_type=calc_type,
        plot_type=plot_type,
        transform=transform,
        subset=subset,
        approx_lat=lat,
        approx_lon=lon,
        lev=lev,
        color=color,
        quantile=quantile,
        calc_ssim=calc_ssim,
        legend_loc=legend_loc,
        axes_symmetric=axes_symmetric,
        short_title=short_title,
        vert_plot=vert_plot,
        tex_format=tex_format,
        legend_offset=legend_offset,
        weighted=weighted,
    )

    plt.rcParams.update(
        {
            'text.usetex': tex_format,
        }
    )

    mp.verify_plot_parameters()

    # Workaround to avoid an unneeded cf_xarray error message; can be dropped once cf_xarray
    # no longer needs it. NOTE: this removes the attribute from the caller's dataset.
    if 'bounds' in ds['time'].attrs:
        ds['time'].attrs.pop('bounds')

    # Subset data (by var and collection)
    da = ds[varname]
    if 'collection' in da.dims and sets is not None:
        dss = [da.sel(collection=s) for s in sets]
    else:
        dss = [da]

    subsets = []
    if sets is not None:
        for i in range(len(sets)):
            sub = lu.subset_data(dss[i], subset, lat, lon, lev, start, end)
            # copy the attrs so adding cell_measures does not modify ds[varname].attrs
            sub.attrs = dict(dss[i].attrs)
            sub.attrs['cell_measures'] = 'area: cell_area'
            subsets.append(sub)

    # Acquire raw calc values
    if calc_type in ['calc_of_diff']:
        datas = []
        for i in range(1, len(subsets)):
            diff = subsets[0] - subsets[i]
            diff.attrs = dict(subsets[0].attrs)
            datas.append(diff)
    else:
        datas = list(subsets)

    raw_calcs = [mp.get_calcs(d) for d in datas]

    # Get calc names/values for plot title
    calc_names = [mp.get_calc_label(calc, d) for d in datas]

    # Get plot data and title
    if lat is not None and lon is not None:
        # get lat/lon dimension names (POP output has 2-D lat/lon coordinates)
        first = datas[0]
        lat_dims = first.cf[first.cf.coordinates['latitude'][0]].dims
        lon_dims = first.cf[first.cf.coordinates['longitude'][0]].dims
        lat_coord_name = lat_dims[0]
        if getattr(ds, 'data_type', None) == 'pop':
            lon_coord_name = lon_dims[1]
        else:
            lon_coord_name = lon_dims[0]

        # is this a 1D or 2D lat/lon?
        dd = subsets[0].cf['latitude'].dims
        if len(dd) == 1:
            mp.title_lat = subsets[0][lat_coord_name].data[0]
            # NOTE(review): subtracting 180 is only a correct 0..360 -> -180..180 conversion
            # if subset_data shifts longitudes; confirm against ldcpy.util.subset_data.
            mp.title_lon = subsets[0][lon_coord_name].data[0] - 180
        else:  # 2
            # lon should be 0- 360
            mylat = subsets[0][lat_coord_name].data[0]
            mylon = subsets[0][lon_coord_name].data[0]
            if mylon < 0:
                mylon = mylon + 360
            mp.title_lat = mylat
            mp.title_lon = mylon
    else:
        mp.title_lat = lat
        mp.title_lon = lon

    plot_datas = []
    set_names = []

    if calc_type in ['diff', 'ratio']:
        for i in range(1, len(raw_calcs)):
            plot_datas.append(mp.get_plot_data(raw_calcs[0], raw_calcs[i]))
            set_names.append(tex_escape(f'{sets[0]} & {sets[i]}'))
    else:
        for i in range(len(raw_calcs)):
            plot_datas.append(mp.get_plot_data(raw_calcs[i]))
            if calc_type in ['calc_of_diff']:
                set_names.append(tex_escape(f'{sets[0]} & {sets[i+1]}'))
            else:
                set_names.append(f'{sets[i]}')

    plot_dataset = xr.concat(plot_datas, 'sets')
    plot_dataset = plot_dataset.assign_coords({'sets': set_names})

    titles = []
    if calc_type in ['ratio', 'diff']:
        for i in range(1, len(calc_names)):
            titles.append(mp.get_title(calc_names[i], f'{sets[0]} & {sets[i]}'))
    elif calc_type in ['calc_of_diff']:
        for i in range(len(calc_names)):
            titles.append(mp.get_title(calc_names[i], f'{sets[0]} & {sets[i+1]}'))
    else:
        for i in range(len(calc_names)):
            titles.append(mp.get_title(calc_names[i], sets[i]))

    # Call plot functions
    if plot_type == 'spatial':
        mp.spatial_plot(plot_dataset, titles)
    elif plot_type == 'time_series':
        mp.time_series_plot(plot_dataset, titles)
    elif plot_type == 'histogram':
        mp.hist_plot(plot_dataset, titles)
    elif plot_type == 'periodogram':
        mp.periodogram_plot(plot_dataset, titles)