import collections
import csv
import os
import cf_xarray as cf
import dask
import numpy as np
import xarray as xr
from .calcs import Datasetcalcs, Diffcalcs
[docs]
def collect_datasets(
    data_type, varnames, list_of_ds, labels, coords_ds=None, file_sizes=None, **kwargs
):
    """
    Concatenate several xarray datasets across a new "collection"
    dimension, which can be accessed with the specified labels.
    Stores them in an xarray dataset which can be passed to the ldcpy
    plot functions (call this OR open_datasets() before plotting).

    The caller's ``varnames`` and ``list_of_ds`` are not modified.

    Parameters
    ==========
    data_type: string
        Current data types: cam-fv, pop, wrf
    varnames : list (or a single string)
        The variable(s) of interest to combine across input datasets (usually just one)
    list_of_ds : list
        The xarray datasets to be concatenated into a collection
    labels : list
        The respective label to access data from each dataset (also used in plotting fcns)
    coords_ds : xarray dataset
        (optional) An additional dataset that contains lat/lon coords (common for WRF data);
        only consulted when data_type is 'wrf'
    file_sizes : list
        (optional) sizes of the files that each dataset corresponds to (stored in the
        'file_size' attribute, keyed by label)
    **kwargs :
        (optional) – Additional arguments passed on to xarray.concat().

    Returns
    =======
    out : xarray.Dataset
        a collection containing the requested variables from all the datasets

    Raises
    ======
    AssertionError
        If the list lengths disagree, the time dimensions differ in length, or WRF
        datasets lack XLAT/XLONG when no coords_ds is given.
    ValueError
        If data_type is not one of 'cam-fv', 'pop', 'wrf'.

    Notes
    ======
    -WRF data must be postprocessed with xWRF before passing to ldcpy
    (e.g., ds = xr.open_dataset(wrf_file, engine="netcdf4").xwrf.postprocess())
    """
    # Error checking:
    # list_of_ds and labels must be same length
    assert len(list_of_ds) == len(
        labels
    ), 'ERROR:collect_dataset dataset list and labels arguments must be the same length'

    # the number of timeslices must be the same
    sz = np.zeros(len(list_of_ds))
    for i, myds in enumerate(list_of_ds):
        time_name = myds.cf.coordinates['time'][0]
        sz[i] = myds.sizes[time_name]
    assert np.unique(sz).size == 1, 'ERROR: all datasets must have the same length time dimension'

    # file sizes?
    if file_sizes is not None:
        assert len(file_sizes) == len(
            labels
        ), 'ERROR::collect_dataset dataset list and file sizes arguments must be the same length'

    # Work on private copies so the caller's lists are left untouched.
    datasets = list(list_of_ds)
    user_vars = [varnames] if isinstance(varnames, str) else list(varnames)

    # wrf data must contain lat/lon info in same file if a coords_ds is not specified
    if data_type == 'wrf':
        if coords_ds is None:
            assert all(
                'XLAT' in myds.coords and 'XLONG' in myds.coords for myds in datasets
            ), 'ERROR: WRF datasets must contain XLAT and XLONG'
        else:
            # copy the (time-independent) coords to EACH of the datasets
            ds_notime = coords_ds.drop_dims('Time')
            datasets = [
                myds.assign_coords(ds_notime.coords).copy(deep=True) for myds in datasets
            ]

    # weights?
    if data_type == 'cam-fv':
        weights_name = 'gw' if 'gw' in datasets[0].variables else None
    elif data_type == 'pop':
        weights_name = 'TAREA'
    elif data_type == 'wrf':
        weights_name = None
    else:
        raise ValueError(
            f"Unsupported data_type {data_type!r}; expected 'cam-fv', 'pop' or 'wrf'"
        )
    weighted = weights_name is not None

    selected_vars = list(user_vars)
    if weighted and weights_name not in selected_vars:
        selected_vars.append(weights_name)

    # keep only the variables of interest (plus the weights variable, if any)
    datasets = [myds[selected_vars] for myds in datasets]
    full_ds = xr.concat(datasets, 'collection', **kwargs)

    # Only build cell_area when a weights variable actually exists.
    if weighted:
        weights_da = xr.DataArray(full_ds.variables[weights_name])
        if data_type == 'pop':
            full_ds.coords['cell_area'] = weights_da[0]
        elif data_type == 'cam-fv':
            full_ds.coords['cell_area'] = weights_da.expand_dims(
                lon=full_ds.sizes['lon']
            ).transpose()
    full_ds.attrs['cell_measures'] = 'area: cell_area'
    if weighted:
        full_ds = full_ds.drop_vars(weights_name)

    full_ds['collection'] = xr.DataArray(labels, dims='collection')
    print('dataset size in GB {:0.2f}\n'.format(full_ds.nbytes / 1e9))
    full_ds.attrs['data_type'] = data_type
    full_ds.attrs['file_size'] = None if file_sizes is None else dict(zip(labels, file_sizes))
    full_ds.attrs['weighted'] = weighted

    # Tag every requested variable (the weights variable has been dropped above).
    for v in user_vars:
        if v == weights_name:
            continue
        tagged = []
        for label in labels:
            da = full_ds[v].sel(collection=label)
            da.attrs['data_type'] = data_type
            da.attrs['set_name'] = label
            da.attrs['weighted'] = weighted
            tagged.append(da)
        full_ds[v] = xr.concat(tagged, 'collection')
    return full_ds
[docs]
def combine_datasets(ds_list):
    """
    Merge the data variables of several datasets into the first one.

    Every data variable of each later dataset is assigned onto
    ``ds_list[0]`` (later datasets overwrite same-named variables).
    The first dataset is modified in place and returned.
    """
    combined = ds_list[0]
    remaining = ds_list[1:]
    for other in remaining:
        var_names = other.data_vars.variables.mapping
        for name in var_names:
            combined[name] = other[name]
    return combined
[docs]
def open_datasets(data_type, varnames, list_of_files, labels, weights=True, **kwargs):
    """
    Open several different netCDF files, concatenate across
    a new 'collection' dimension, which can be accessed with the specified
    labels. Stores them in an xarray dataset which can be passed to the ldcpy
    plot functions.

    The caller's ``varnames`` list is not modified.

    Parameters
    ==========
    data_type: string
        Current data types: cam-fv, pop
    varnames : list (or a single string)
        The variable(s) of interest to combine across input files (usually just one)
    list_of_files : list
        The file paths for the netCDF file(s) to be opened
    labels : list
        The respective label to access data from each netCDF file (also used in plotting fcns)
    weights : bool, optional
        Whether to load the weights variable ('gw' for cam-fv when present in the
        first file, 'TAREA' for pop) and expose it as the 'cell_area' coordinate
        (default True)
    **kwargs :
        (optional) – Additional arguments passed on to xarray.open_mfdataset().

    Returns
    =======
    out : xarray.Dataset
        a collection containing the requested variables from the list of files

    Raises
    ======
    AssertionError
        If the file list and labels differ in length, or data_type is 'wrf'.
    ValueError
        If data_type is not 'cam-fv' or 'pop'.

    Notes
    ======
    wrf netcdf data must be postprocessed with xwrf, e.g.
    ds = xr.open_dataset(wrf_file, engine="netcdf4").xwrf.postprocess()
    so collect_datasets must be used for it instead.
    """
    # Error checking:
    # list_of_files and labels must be same length
    assert len(list_of_files) == len(
        labels
    ), 'ERROR: open_dataset file list and labels arguments must be the same length'
    # can't use wrf with this function
    assert (
        data_type != 'wrf'
    ), 'ERROR: WRF files must be postprocessed with xWRF and passed to collect_dataset'
    if data_type not in ('cam-fv', 'pop'):
        raise ValueError(f"Unsupported data_type {data_type!r}; expected 'cam-fv' or 'pop'")

    # Record each file's time length and on-disk size. The time lengths are
    # only recorded: the equal-length check is currently disabled.
    sz = np.zeros(len(list_of_files))
    file_size_dict = {}
    for i, myfile in enumerate(list_of_files):
        with xr.open_dataset(myfile) as myds:
            sz[i] = myds.sizes['time']
        file_size_dict[labels[i]] = os.path.getsize(myfile)

    user_vars = [varnames] if isinstance(varnames, str) else list(varnames)

    # check the weights
    weights_name = None
    if weights is True:
        if data_type == 'cam-fv':
            # 'gw' is optional in cam-fv output; probe the first file for it
            with xr.open_dataset(list_of_files[0]) as first_ds:
                if 'gw' in first_ds.variables:
                    weights_name = 'gw'
        elif data_type == 'pop':
            weights_name = 'TAREA'
    weighted = weights_name is not None

    selected_vars = list(user_vars)
    if weighted and weights_name not in selected_vars:
        selected_vars.append(weights_name)

    # NOTE(review): hard-coded cap on the number of time steps kept from each
    # file; presumably dataset-specific - confirm whether it should be configurable.
    max_time_steps = 25100

    # preprocess_vars is here for working on jupyter hub...
    def preprocess_vars(ds):
        # Trim the time axis, then select the specified variables
        return ds.isel(time=slice(0, max_time_steps))[selected_vars]

    full_ds = xr.open_mfdataset(
        list_of_files,
        concat_dim='collection',
        combine='nested',
        data_vars=selected_vars,
        parallel=True,
        preprocess=preprocess_vars,
        **kwargs,
    )

    if weighted:
        weights_da = xr.DataArray(full_ds.variables[weights_name])
        if data_type == 'pop':
            full_ds.coords['cell_area'] = weights_da[0]
        else:  # cam-fv
            full_ds.coords['cell_area'] = weights_da.expand_dims(
                lon=full_ds.sizes['lon']
            ).transpose()
    full_ds.attrs['cell_measures'] = 'area: cell_area'
    if weighted:
        full_ds = full_ds.drop_vars(weights_name)

    full_ds['collection'] = xr.DataArray(labels, dims='collection')
    print('dataset size in GB {:0.2f}\n'.format(full_ds.nbytes / 1e9))
    full_ds.attrs['data_type'] = data_type
    full_ds.attrs['file_size'] = file_size_dict
    full_ds.attrs['weighted'] = weighted

    # Tag every requested variable (the weights variable has been dropped above).
    for v in user_vars:
        if v == weights_name:
            continue
        tagged = []
        for label in labels:
            da = full_ds[v].sel(collection=label)
            da.attrs['data_type'] = data_type
            da.attrs['set_name'] = label
            da.attrs['weighted'] = weighted
            tagged.append(da)
        full_ds[v] = xr.concat(tagged, 'collection')
    return full_ds
[docs]
def compare_stats(
    ds,
    varname: str,
    sets,
    significant_digits: int = 5,
    include_ssim: bool = False,
    weighted: bool = True,
    **calcs_kwargs,
):
    """
    Print error summary statistics for multiple DataArrays (should just be a single time slice)

    Two tables are displayed: per-set statistics for every label in ``sets``,
    and difference statistics comparing each later set to the first one.

    Parameters
    ==========
    ds : xarray.Dataset
        An xarray dataset containing multiple netCDF files concatenated across a 'collection'
        dimension; its 'data_type', 'weighted' and 'file_size' attributes are read
    varname : str
        The variable of interest in the dataset
    sets: list of str
        The labels of the collection to compare (all will be compared to the first set)
    significant_digits : int, optional
        The number of significant digits to use when printing stats (default 5)
    include_ssim : bool, optional
        Whether or not to compute the image ssim - slow for 3D vars (default: False)
    weighted : bool, optional
        Whether or not weight the means (default = True)
    **calcs_kwargs :
        Additional keyword arguments passed through to the
        :py:class:`~ldcpy.Datasetcalcs` instance.

    Returns
    =======
    out : None
    """
    # get a dataarray for the variable of interest and get collections
    # (this is done separately to work with cf_xarray)
    da = ds[varname]
    data_type = ds.attrs['data_type']
    attr_weighted = ds.attrs['weighted']

    # cam-fv data without weights falls back to unweighted averages
    if data_type == 'cam-fv' and weighted and not attr_weighted:
        print('Warning - this data does not contain weights, so averages will be unweighted.')
        weighted = False
    # no weights for wrf
    if data_type == 'wrf':
        weighted = False

    file_size_dict = ds.attrs['file_size']
    include_file_size = file_size_dict is not None

    da.attrs['cell_measures'] = 'area: cell_area'

    # do we have more than one time slice? Should only have one..
    if 'time' in da.dims:
        print('Warning - this data set has a time dimension - examining slice 0 only...')
        da = da.isel(time=0)

    # see how many sets we have
    num = len(sets)
    if num < 2:
        print('Error: must specify at least two sets to compare!')
        return

    da_sets = [da.sel(collection=label) for label in sets]
    control = da_sets[0]
    dd_sets = [control - other for other in da_sets[1:]]

    aggregate_dims = calcs_kwargs.pop('aggregate_dims', None)

    da_set_calcs = [
        Datasetcalcs(one_set, data_type, aggregate_dims, **calcs_kwargs, weighted=weighted)
        for one_set in da_sets
    ]
    dd_set_calcs = [
        Datasetcalcs(one_diff, data_type, aggregate_dims, **calcs_kwargs, weighted=weighted)
        for one_diff in dd_sets
    ]
    diff_calcs = [
        Diffcalcs(control, other, data_type, aggregate_dims, **calcs_kwargs, weighted=weighted)
        for other in da_sets[1:]
    ]

    # are the arrays using dask
    using_dask = control.chunks is not None

    def _materialize(values):
        # only use compute if it's a dask array
        return values.compute() if using_dask else values

    def _format_rows(rows):
        fmt = f'%.{significant_digits}g'
        return {label: [fmt % item for item in items] for label, items in rows.items()}

    # DATA FRAME (imported lazily so they are only needed when this function is used)
    import pandas as pd
    from IPython.display import HTML, display

    # --- table 1: statistics of each set; (row label, calc name) in display order ---
    array_calcs = [
        ('mean', 'mean'),
        ('variance', 'variance'),
        ('standard deviation', 'std'),
        ('min value', 'min_val'),
        ('min (abs) nonzero value', 'min_abs_nonzero'),
        ('max value', 'max_val'),
        ('probability positive', 'prob_positive'),
        ('number of zeros', 'num_zero'),
    ]
    single_calcs = [('99% real information cutoff bit', 'real_information_cutoff')]
    if data_type == 'cam-fv':
        single_calcs += [
            ('spatial autocorr - latitude', 'lat_autocorr'),
            ('spatial autocorr - longitude', 'lon_autocorr'),
            ('entropy estimate', 'entropy'),
        ]

    df_dict = {label: [] for label, _ in array_calcs + single_calcs}
    for calcs in da_set_calcs:
        for label, calc_name in array_calcs:
            df_dict[label].append(_materialize(calcs.get_calc(calc_name).data))
        for label, calc_name in single_calcs:
            df_dict[label].append(calcs.get_single_calc(calc_name))

    df = pd.DataFrame.from_dict(_format_rows(df_dict), orient='index', columns=list(sets))

    display(
        HTML(
            ' <span style="color:blue">Variable: <b>{varname}</b> </span> '.format(varname=varname)
        )
    )
    display(HTML(' <span style="color:green">Comparison: </span> '))
    display(df)

    # --- table 2: each later set compared to the first set ---
    compared_labels = [sets[i] for i in range(1, num)]

    diff_array_calcs = [
        ('max abs diff', 'max_abs'),
        ('min abs diff', 'min_abs'),
        ('mean abs diff', 'mean_abs'),
        ('mean squared diff', 'mean_squared'),
        ('root mean squared diff', 'rms'),
    ]
    # tolerance used for the spatial relative error row
    rel_error_tol = 0.0001
    pair_calcs = [
        ('normalized root mean squared diff', 'n_rms'),
        ('normalized max pointwise error', 'n_emax'),
        ('pearson correlation coefficient', 'pearson_correlation_coefficient'),
        ('ks p-value', 'ks_p_value'),
        ('spatial relative error(% > ' + str(rel_error_tol) + ')', 'spatial_rel_error'),
        ('max spatial relative error', 'max_spatial_rel_error'),
        ('DSSIM', 'ssim_fp'),
    ]
    if include_ssim:
        pair_calcs.append(('image SSIM', 'ssim'))

    df_dict2 = {label: [] for label, _ in diff_array_calcs + pair_calcs}
    if include_file_size:
        df_dict2['file size ratio'] = []
        fs_orig = file_size_dict[sets[0]]

    for calcs in dd_set_calcs:
        for label, calc_name in diff_array_calcs:
            df_dict2[label].append(_materialize(calcs.get_calc(calc_name).data))

    for label_of_set, pair in zip(compared_labels, diff_calcs):
        pair.spre_tol = rel_error_tol
        for label, calc_name in pair_calcs:
            df_dict2[label].append(pair.get_diff_calc(calc_name))
        if include_file_size:
            this_fs = file_size_dict[label_of_set]
            df_dict2['file size ratio'].append(round(fs_orig / this_fs, 2))

    df2 = pd.DataFrame.from_dict(
        _format_rows(df_dict2), orient='index', columns=compared_labels
    )

    display(HTML('<br>'))
    display(HTML('<span style="color:green">Difference calcs: </span> '))
    display(df2)
[docs]
def check_metrics(
    ds,
    varname,
    set1,
    set2,
    ks_tol=0.05,
    pcc_tol=0.99999,
    spre_tol=5.0,
    ssim_tol=0.995,
    **calcs_kwargs,
):
    """
    Run four pass/fail checks (Pearson correlation, K-S, spatial relative
    error, Data SSIM) comparing two members of a collection, printing the
    outcome of each.

    Parameters
    ==========
    ds : xarray.Dataset
        An xarray dataset containing multiple netCDF files concatenated across a 'collection' dimension
    varname : str
        The variable of interest in the dataset
    set1 : str
        The collection label of the "control" data
    set2 : str
        The collection label of the data to compare
    ks_tol : float, optional
        The p-value threshold (significance level) for the K-S test (default = .05)
    pcc_tol: float, optional
        The Pearson correlation coefficient threshold (default = .99999)
    spre_tol: float, optional
        The percentage threshold for failing grid points in the spatial relative error test (default = 5.0).
    ssim_tol: float, optional
        The threshold for the data ssim test (default = .995)
    **calcs_kwargs :
        Additional keyword arguments passed through to the
        :py:class:`~ldcpy.Diffcalcs` instance ('aggregate_dims' is extracted
        and passed positionally).

    Returns
    =======
    out : Number of failing calcs

    Notes
    ======
    The K-S, Pearson Correlation, and Spatial Relative Error calcs are from:
    A. H. Baker, H. Xu, D. M. Hammerling, S. Li, and J. Clyne,
    “Toward a Multi-method Approach: Lossy Data Compression for
    Climate Simulation Data”, in J.M. Kunkel et al. (Eds.): ISC
    High Performance Workshops 2017, Lecture Notes in Computer
    Science 10524, pp. 30–42, 2017 (doi:10.1007/978-3-319-67630-2_3).

    The Data SSIM is a modification of the SSIM calc from:
    A.H. Baker, D.M. Hammerling, and T.L. Turton. “Evaluating image
    quality measures to assess the impact of lossy data compression
    applied to climate simulation data”, Computer Graphics Forum 38(3),
    June 2019, pp. 517-528 (doi:10.1111/cgf.13707).

    Default tolerances for the tests are:
    ------------------------
    K-S: fail if p-value < .05 (significance level)
    Pearson correlation coefficient: fail if coefficient < .99999
    Spatial relative error: fail if > 5% of grid points fail relative error
    Data SSIM: fail if Data SSIM < .995
    """

    def _report(failed, failed_template, passed_template, shown_value, closing):
        # Print the matching line and return 1 for a failure, 0 for a pass.
        if failed:
            print(failed_template.format(shown_value), closing)
            return 1
        print(passed_template.format(shown_value), closing)
        return 0

    print(f'Evaluating 4 calcs for {set1} data (set1) and {set2} data (set2):')

    data_type = ds.attrs['data_type']
    aggregate_dims = calcs_kwargs.pop('aggregate_dims', None)
    control = ds[varname].sel(collection=set1)
    candidate = ds[varname].sel(collection=set2)
    diff_calcs = Diffcalcs(
        control,
        candidate,
        data_type,
        aggregate_dims,
        **calcs_kwargs,
        weighted=False,
    )

    # running count of failed checks
    num_fail = 0

    # Pearson less than pcc_tol means fail
    pcc = diff_calcs.get_diff_calc('pearson_correlation_coefficient')
    num_fail = num_fail + _report(
        pcc < pcc_tol,
        ' *FAILED pearson correlation coefficient test...(pcc = {0:.5f}',
        ' PASSED pearson correlation coefficient test...(pcc = {0:.5f}',
        pcc.values,
        ')',
    )

    # K-S p-value less than ks_tol means fail (can reject null hypo)
    ks = diff_calcs.get_diff_calc('ks_p_value')
    num_fail = num_fail + _report(
        ks < ks_tol,
        ' *FAILED ks test...(ks p_val = {0:.5f}',
        ' PASSED ks test...(ks p_val = {0:.5f}',
        ks.values,
        ')',
    )

    # Spatial rel error fails if more than spre_tol
    spre = diff_calcs.get_diff_calc('spatial_rel_error')
    num_fail = num_fail + _report(
        spre > spre_tol,
        ' *FAILED spatial relative error test ... (spre = {0:.2f}',
        ' PASSED spatial relative error test ...(spre = {0:.2f}',
        spre.values,
        ' %)',
    )

    # SSIM less than ssim_tol is failing
    ssim_val = diff_calcs.get_diff_calc('ssim_fp')
    num_fail = num_fail + _report(
        ssim_val < ssim_tol,
        ' *FAILED DATA SSIM test ... (ssim = {0:.5f}',
        ' PASSED DATA SSIM test ... (ssim = {0:.5f}',
        ssim_val,
        ')',
    )

    if num_fail > 0:
        print(f'WARNING: {num_fail} of 4 tests failed.')
    return num_fail
[docs]
def subset_data(
    ds,
    subset=None,
    lat=None,
    lon=None,
    lev=None,
    start=None,
    end=None,
    time_dim_name=None,
    vertical_dim_name=None,
    lat_coord_name=None,
    lon_coord_name=None,
):
    """
    Get a subset of the given dataArray, returns a dataArray

    Parameters
    ==========
    ds : xarray object
        The data to subset (accessed through the cf_xarray ``.cf`` accessor)
    subset : str, optional
        'DJF', 'MAM', 'JJA' or 'SON' to keep one season, or 'first5' to keep
        the first five time slices; other values are ignored
    lat, lon : float, optional
        Location to select (nearest point). For 2D lat/lon coordinates both
        must be given and the data_type attribute must be 'pop' or 'wrf'
    lev : int, optional
        Index along the vertical dimension (applied only if that dimension exists)
    start, end : int, optional
        Inclusive time-index range; applied only when both are given
    time_dim_name, vertical_dim_name, lat_coord_name, lon_coord_name : str, optional
        Names to use; when omitted they are looked up via ``ds.cf.coordinates``

    Returns
    =======
    out : the subsetted data (``ds`` itself when nothing was selected)
    """
    ds_subset = ds

    if lon_coord_name is None:
        lon_coord_name = ds.cf.coordinates['longitude'][0]
    if lat_coord_name is None:
        lat_coord_name = ds.cf.coordinates['latitude'][0]
    if vertical_dim_name is None:
        # the vertical coordinate is optional
        try:
            ds.cf['vertical']
        except KeyError:
            pass
        else:
            vertical_dim_name = ds.cf.coordinates['vertical'][0]
    if time_dim_name is None:
        time_dim_name = ds.cf.coordinates['time'][0]

    # 1 => rectilinear lat/lon vectors, 2 => curvilinear 2D lat/lon arrays
    latdim = ds_subset.cf[lon_coord_name].ndim

    # need dim names
    dd = ds_subset.cf['latitude'].dims
    if latdim == 1:
        lat_dim_name = dd[0]
        lon_dim_name = ds_subset.cf['longitude'].dims[0]
    elif latdim == 2:
        lat_dim_name = dd[0]
        lon_dim_name = dd[1]

    if start is not None and end is not None:
        ds_subset = ds_subset.isel({time_dim_name: slice(start, end + 1)})

    if subset in ('DJF', 'MAM', 'JJA', 'SON'):
        # Build the mask from the (possibly already time-sliced) subset so its
        # length always matches the data being indexed.
        ds_subset = ds_subset.cf.sel(time=ds_subset.cf['time'].dt.season == subset)
    elif subset == 'first5':
        ds_subset = ds_subset.isel({time_dim_name: slice(None, 5)})

    if lev is not None and vertical_dim_name in ds_subset.dims:
        ds_subset = ds_subset.isel({vertical_dim_name: lev})

    if latdim == 1:
        if lat is not None:
            ds_subset = ds_subset.sel(**{lat_coord_name: [lat], 'method': 'nearest'})
        if lon is not None:
            # NOTE(review): the +180 shift presumably maps the requested
            # longitude onto the data's longitude convention - confirm.
            ds_subset = ds_subset.sel(**{lon_coord_name: [lon + 180], 'method': 'nearest'})
    elif latdim == 2 and lat is not None and lon is not None:
        data_type = ds_subset.data_type
        if data_type in ('pop', 'wrf'):
            ad_lon = lon
            if data_type == 'pop' and ad_lon < 0:
                # lat is -90 to 90, lon should be 0-360
                ad_lon = ad_lon + 360

            mlat = ds_subset[lat_coord_name].compute()
            mlon = ds_subset[lon_coord_name].compute()
            # euclidean dist for now....
            di = np.sqrt(np.square(ad_lon - mlon) + np.square(lat - mlat))
            index = np.where(di == np.min(di))
            xmin = index[0][0]
            ymin = index[1][0]

            if data_type == 'pop':
                # Don't want it if it's a land point. Probe time index 1 when
                # available, otherwise the only time slice.
                time_index = 1 if ds_subset.sizes[time_dim_name] > 1 else 0
                check = ds_subset.isel(
                    {lat_dim_name: xmin, lon_dim_name: ymin, time_dim_name: time_index}
                ).compute()
                if np.isnan(check):
                    print(
                        'You have chosen a lat/lon point with Nan values (i.e., a land point). Plot will not make sense.'
                    )
            # TO DO (wrf): add error checking for if it's out of bounds

            ds_subset = ds_subset.isel({lat_dim_name: [xmin], lon_dim_name: [ymin]})

    return ds_subset
[docs]
def var_and_wt_coords(varname, ds_col):
    """
    Return the variable ``varname`` from ``ds_col`` with the collection's
    'cell_area' coordinate attached and its 'cell_measures' attribute set.

    Parameters
    ==========
    varname : str
        Name of the variable to extract (looked up through the ``.cf`` accessor)
    ds_col : xarray.Dataset
        Collection dataset that carries a 'cell_area' coordinate

    Returns
    =======
    out : the extracted variable with the 'cell_area' coordinate
    """
    ca_coord = ds_col.coords['cell_area']
    # load the weights eagerly if they are dask-backed
    if dask.is_dask_collection(ca_coord):
        ca_coord = ca_coord.compute()
    ds0 = ds_col.cf[varname]
    # assign_coords returns a new object, so its result must be kept
    ds0 = ds0.assign_coords(cell_area=ca_coord)
    ds0.attrs['cell_measures'] = 'area: cell_area'
    return ds0
[docs]
def save_metrics(
    full_ds,
    varname,
    set1,
    set2,
    time=0,
    lev=0,
    location='names.csv',
):
    """
    Compute the Data SSIM ('ssim_fp') between two collection members at one
    time index and append it as a row to a CSV file. A header row is written
    first when the file does not exist yet.

    Parameters
    ==========
    full_ds : xarray.Dataset
        An xarray dataset containing multiple netCDF files concatenated across a 'collection' dimension
    varname : str
        The variable of interest in the dataset
    set1 : str
        The collection label of the "control" data
    set2 : str
        The collection label of the data to compare
    time : int, optional
        The time index used (default = 0)
    lev : int, optional
        The level index of interest in a 3D dataset (default 0); currently
        not used by this function
    location : str, optional
        Path of the CSV file to append to (default 'names.csv')

    Returns
    =======
    out : Number of failing metrics (no checks are performed here, so always 0)
    """
    ds = subset_data(full_ds)

    # no pass/fail checks are made, so the failure count stays at zero
    num_fail = 0

    print(
        'Evaluating metrics for {} data (set1) and {} data (set2), time {}'.format(
            set1, set2, time
        ),
        ':',
    )

    control = ds[varname].sel(collection=set1).isel(time=time)
    candidate = ds[varname].sel(collection=set2).isel(time=time)
    diff_metrics = Diffcalcs(control, candidate, 'cam-fv', ['lat', 'lon'])

    # SSIM
    row = {
        'set': set2,
        'time': time,
        'ssim_fp': diff_metrics.get_diff_calc('ssim_fp'),
    }

    # must be decided before opening: append mode creates the file
    needs_header = not os.path.isfile(location)
    with open(location, 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=list(row))
        if needs_header:
            writer.writeheader()
        writer.writerow(row)

    return num_fail