# -*- coding: utf-8 -*-
"""
This module contains the functions for calculating parameter based validation statistics for
a set of matchup database records.
"""
# ----------------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------------
# Standard Python Dependencies
import os
import copy
# Non-Standard Python Dependencies
import numpy as np
import scipy.io as sio
from scipy.interpolate import interp1d
# Local Module Dependencies
from ocmw.core.graphics import plot_time_series, plot_correlation, plot_close
from ocmw.core.timeFuncs import temporalCoverage
from ocmw.core.statsTools import corr_coef, mean_bias, norm_mean_bias
from ocmw.core.statsTools import mean_absolute_error, norm_ma_error
from ocmw.core.statsTools import rms_error, norm_rmse, scatter_index
from ocmw.core.statsTools import circ_corr, circ_mean_bias, circ_norm_mean_bias
from ocmw.core.statsTools import circ_mean_abs_error, circ_norm_ma_error
from ocmw.core.statsTools import circ_rms_error, circ_norm_rmse
from ocmw.core.statsTools import circ_scatter_index
from ocmw.dataman.dataReaders import wavebuoy, emecwb, ww3
# Other Dependencies
# ----------------------------------------------------------------------------
# GLOBAL VARIABLES
# ----------------------------------------------------------------------------
# ----------------------------------------------------------------------------
# CLASS DEFINITIONS
# ----------------------------------------------------------------------------
# ----------------------------------------------------------------------------
# FUNCTION DEFINITIONS
# ----------------------------------------------------------------------------
# ======================== Data Selection =================================
def getVarName(record, Options: list):
    """
    Get the first variable name from a list of processing options that is
    present in a data record.

    Parameters
    ----------
    record : object
        Data record exposing a ``hasVar(name)`` method whose result is
        truthy when the record contains the named variable.
    Options : list
        Candidate variable names, in order of preference.

    Returns
    -------
    str
        The first entry of *Options* for which ``record.hasVar`` is truthy,
        or ``''`` when no candidate is present (including empty *Options*).
    """
    # next() stops at the first match; later candidates are not queried.
    return next((vn for vn in Options if record.hasVar(vn)), '')
def getSurfaceIndx(record, dataFmt):
    """
    Get the index of the surface layer in a record's depth data.

    Parameters
    ----------
    record : object
        Data record exposing ``getVar(name)``. The depth variable is read
        as ``'DEPH'`` for ``'InSituTAC'`` data and as ``'DEPTH'`` for
        ``'SeaDataNet'`` data.
    dataFmt : str
        Data format, either ``'InSituTAC'`` or ``'SeaDataNet'``.

    Returns
    -------
    int
        0 when the depth data has at most one dimension. Otherwise, the
        first column index in the first row of the depth array whose value
        is exactly 0.0.

    Raises
    ------
    ValueError
        If *dataFmt* is not supported, or if a multi-dimensional depth
        array has no 0.0 entry in its first row.
    """
    depth_var_names = {'InSituTAC': 'DEPH', 'SeaDataNet': 'DEPTH'}
    if dataFmt not in depth_var_names:
        # Without this check `depth` would be unbound below, and the caller
        # would get an opaque UnboundLocalError.
        raise ValueError('Unsupported data format for depth lookup: '
                         + repr(dataFmt))
    depth = record.getVar(depth_var_names[dataFmt])
    if np.ndim(depth) <= 1:
        return 0
    surfIndx = np.where(depth[0, :] == 0.0)[0]
    if surfIndx.size == 0:
        raise ValueError('No zero-depth level found in the first row of '
                         'the depth data')
    return int(surfIndx[0])
# =========================== Data Extraction =============================
# ========================== Data Interpolators ===========================
def time_interpolate_obs(t_obs, t_var, t_model, max_obs_gap=1.5,
                         max_offset=1.0):
    """
    Interpolate observation data onto the model output timestamps.

    For each model time, find the two observations closest in time. The
    model time receives the value of the nearer of the two
    (nearest-neighbour), but only when both of these hold:

    - the two observations are less than *max_obs_gap* hours apart;
    - the nearer one is less than *max_offset* hours from the model time.

    Otherwise the model time gets NaN.

    Time differences are multiplied by 24 before being compared with the
    thresholds, so times are presumably in days (e.g. datenums). Confirm
    this against the callers.

    Parameters
    ----------
    t_obs : array_like
        Observation times (1-D).
    t_var : array_like
        Observation values, one per entry of *t_obs*.
    t_model : array_like
        Model times onto which the observations are placed.
    max_obs_gap : float, optional
        Maximum spacing, in hours, between the two nearest observations
        (default 1.5).
    max_offset : float, optional
        Maximum distance, in hours, between a model time and its nearest
        observation (default 1.0).

    Returns
    -------
    numpy.ndarray
        Float array shaped like *t_model*, NaN where no observation
        qualifies. It is all NaN when fewer than two observations are given.
    """
    t_obs = np.asarray(t_obs)
    intrp_var = np.full(np.shape(t_model), np.nan)
    # The spacing test needs a pair of observations. With fewer, nothing can
    # qualify; without this guard the loop would take np.min of an empty
    # array or the truth value of an empty array.
    if t_obs.size < 2:
        return intrp_var
    for i_ref, t_ref in enumerate(t_model):
        dist = np.abs(t_obs - t_ref)
        # Indices of the two closest observations, in array order.
        idx = np.sort(np.argpartition(dist, 1)[:2])
        obs_gap = abs(t_obs[idx[1]] - t_obs[idx[0]]) * 24.0
        offset = dist[idx].min() * 24.0
        if obs_gap < max_obs_gap and offset < max_offset:
            # Take the nearer of the pair. On an exact tie argmin picks the
            # earlier index, matching interp1d(kind='nearest'), which rounds
            # down at the midpoint.
            intrp_var[i_ref] = t_var[idx[np.argmin(dist[idx])]]
    return intrp_var
# ============================ Analysis Tools =============================
def initialize_results():
    """
    Create an empty validation-results store.

    Returns
    -------
    dict
        Maps each column name to its own new, empty list, with keys in
        this order: MODEL, OBSERVATION, YEAR, MONTH, PARAM, R, MB, NMB,
        MAE, NMAE, RMSE, NRMSE, SI, NSMPL.
    """
    columns = ('MODEL', 'OBSERVATION', 'YEAR', 'MONTH', 'PARAM',
               'R', 'MB', 'NMB', 'MAE', 'NMAE', 'RMSE', 'NRMSE', 'SI',
               'NSMPL')
    # The comprehension creates a distinct list per column.
    return {column: [] for column in columns}
def calculate_circular_metrics(model, obs, nsampl, degrees=True):
    """
    Calculate circular validation metrics for a pair of modelled and
    observed time series.

    Parameters
    ----------
    model, obs : array_like
        Paired model and observation values.
    nsampl : int
        Sample count, stored unchanged under ``'NSMPL'``.
    degrees : bool, optional
        Passed through to every ``circ_*`` statistic (default True).

    Returns
    -------
    dict
        Keys R, MB, NMB, MAE, NMAE, RMSE, NRMSE, SI (results of circ_corr,
        circ_mean_bias, circ_norm_mean_bias, circ_mean_abs_error,
        circ_norm_ma_error, circ_rms_error, circ_norm_rmse and
        circ_scatter_index respectively) and NSMPL.
    """
    # A dict literal evaluates its values left to right, so the statistics
    # run in the listed order.
    return {
        'R': circ_corr(model, obs, degrees),
        'MB': circ_mean_bias(model, obs, degrees),
        'NMB': circ_norm_mean_bias(model, obs, degrees),
        'MAE': circ_mean_abs_error(model, obs, degrees),
        'NMAE': circ_norm_ma_error(model, obs, degrees),
        'RMSE': circ_rms_error(model, obs, degrees),
        'NRMSE': circ_norm_rmse(model, obs, degrees),
        'SI': circ_scatter_index(model, obs, degrees),
        'NSMPL': nsampl,
    }
def calculate_metrics(model, obs, nsampl):
    """
    Calculate standard (linear) validation metrics for a pair of modelled
    and observed time series.

    Parameters
    ----------
    model, obs : array_like
        Paired model and observation values.
    nsampl : int
        Sample count, stored unchanged under ``'NSMPL'``.

    Returns
    -------
    dict
        Keys R, MB, NMB, MAE, NMAE, RMSE, NRMSE, SI (results of corr_coef,
        mean_bias, norm_mean_bias, mean_absolute_error, norm_ma_error,
        rms_error, norm_rmse and scatter_index respectively) and NSMPL.
    """
    # A dict literal evaluates its values left to right, so the statistics
    # run in the listed order.
    return {
        'R': corr_coef(model, obs),
        'MB': mean_bias(model, obs),
        'NMB': norm_mean_bias(model, obs),
        'MAE': mean_absolute_error(model, obs),
        'NMAE': norm_ma_error(model, obs),
        'RMSE': rms_error(model, obs),
        'NRMSE': norm_rmse(model, obs),
        'SI': scatter_index(model, obs),
        'NSMPL': nsampl,
    }
def results_record(modelFileName, obsFileName, obsVarName,
                   year, month, metrics):
    """
    Create a results record from calculated validation metrics.

    Parameters
    ----------
    modelFileName, obsFileName : str
        Model and observation file names (stored as MODEL and BUOY).
    obsVarName : str
        Observation variable name (stored as PARAM).
    year, month
        Period of the record.
    metrics : dict
        Metric values; every entry is copied into the record and
        overrides an identically named identification field.

    Returns
    -------
    dict
        MODEL, BUOY, YEAR, MONTH, PARAM followed by the metric entries.
    """
    record = {
        'MODEL': modelFileName,
        'BUOY': obsFileName,
        'YEAR': year,
        'MONTH': month,
        'PARAM': obsVarName,
    }
    record.update(metrics)
    return record
def store_valid_results(store, results):
    """
    Append one validation results record to the results store.

    Parameters
    ----------
    store : dict
        Results store of per-column lists (modified in place).
    results : dict
        A single record. Its ``'BUOY'`` entry is appended to the store's
        ``'OBSERVATION'`` column; every other column is read from the key
        of the same name.
    """
    # (store column, key in the results record), in append order.
    column_sources = (
        ('MODEL', 'MODEL'),
        ('OBSERVATION', 'BUOY'),
        ('YEAR', 'YEAR'),
        ('MONTH', 'MONTH'),
        ('PARAM', 'PARAM'),
        ('R', 'R'),
        ('MB', 'MB'),
        ('NMB', 'NMB'),
        ('MAE', 'MAE'),
        ('NMAE', 'NMAE'),
        ('RMSE', 'RMSE'),
        ('NRMSE', 'NRMSE'),
        ('SI', 'SI'),
        ('NSMPL', 'NSMPL'),
    )
    for column, source in column_sources:
        store[column].append(results[source])
def display_tabulated_results(store, mVarName):
    """
    Display validation results as a table on the standard output.

    Prints a title, a header row of column names (the keys of *store*) and
    one row per stored record.

    Parameters
    ----------
    store : dict
        Results store mapping each column name to a sequence of values
        (MODEL, OBSERVATION, YEAR, MONTH, PARAM, R, MB, NMB, MAE, NMAE,
        RMSE, NRMSE, SI, NSMPL).
    mVarName : str
        Model variable name shown in the title.
    """
    print('Validation Statistics for', mVarName, '\n')
    rowfmt = ("{:<60}{:<25}{:^7}{:^7}{:^7}{:^9.3}{:^9.4}{:^9.4}{:^9.4}"
              "{:^9.4}{:^9.4}{:^9.4}{:^9.4}{:^9}")
    # Same column widths but no precision. A precision applied to a string
    # truncates it, which turned the 'NRMSE' header into 'NRMS'.
    hdrfmt = "{:<60}{:<25}{:^7}{:^7}{:^7}" + "{:^9}" * 9
    if store:
        print(hdrfmt.format(*store.keys()))
    for row in zip(*store.values()):
        print(rowfmt.format(*row))
def save_tabulated_results(store, platform, mVarName, outpath):
    """
    Save tabulated validation results as an ASCII file and a MATLAB file.

    Writes two files into *outpath*:

    - ``<platform>_VALIDATION_STATS_<mVarName>.txt``: a title line, a header
      row of column names and one line per record;
    - ``<platform>_VALIDATION_STATS_<mVarName>.mat``: the *store* dict,
      written with ``scipy.io.savemat``.

    Parameters
    ----------
    store : dict
        Results store mapping each column name to a sequence of values.
    platform : str
        Platform name used as the file-name prefix.
    mVarName : str
        Model variable name used in the title and file names.
    outpath : str
        Output directory.

    Returns
    -------
    object
        Whatever ``scipy.io.savemat`` returns (presumably None).
    """
    stem = platform + '_VALIDATION_STATS_' + mVarName
    rowfmt = ("{:<60}{:<25}{:^7}{:^7}{:^7}{:^9.3}{:^9.4}{:^9.4}{:^9.4}"
              "{:^9.4}{:^9.4}{:^9.4}{:^9.4}{:^9}")
    # Same column widths but no precision. A precision applied to a string
    # truncates it, which turned the 'NRMSE' header into 'NRMS'.
    hdrfmt = "{:<60}{:<25}{:^7}{:^7}{:^7}" + "{:^9}" * 9
    outtxtfile = os.path.join(outpath, stem + '.txt')
    with open(outtxtfile, 'w', encoding='utf-8') as ofile:
        ofile.write('Validation Statistics for ' + mVarName + '\n')
        if store:
            ofile.write(hdrfmt.format(*store.keys()) + '\n')
        for row in zip(*store.values()):
            ofile.write(rowfmt.format(*row) + '\n')
    outmatfile = os.path.join(outpath, stem + '.mat')
    saveFlg = sio.savemat(outmatfile, store)
    return saveFlg
def save_cleaned_ts_data(t, rscd, buoy, platform, mVarName, outpath, year, month):
    """
    Save the cleaned, matched model and buoy time series as a MATLAB
    \\*.mat binary file.

    The file is ``<platform>_VALIDATION_TS_<mVarName>_<year>_<MM>.mat`` in
    *outpath*, where the month is zero-padded to two digits. It holds the
    variables platform, var, year, month, time, rscd and buoy.

    Returns
    -------
    object
        Whatever ``scipy.io.savemat`` returns.
    """
    period = '_' + str(year) + '_' + str(month).zfill(2)
    filename = ''.join((platform, '_VALIDATION_TS_', mVarName, period, '.mat'))
    contents = {
        'platform': platform,
        'var': mVarName,
        'year': year,
        'month': month,
        'time': t,
        'rscd': rscd,
        'buoy': buoy,
    }
    return sio.savemat(os.path.join(outpath, filename), contents)
def load_binary_results(platform, mVarName, datapath):
    """
    Load validation results from a MATLAB ``.mat`` file.

    Reads ``<platform>_VALIDATION_STATS_<mVarName>.mat`` from *datapath*
    and rebuilds a results store of plain Python lists.

    Parameters
    ----------
    platform : str
        Platform name used as the file-name prefix.
    mVarName : str
        Model variable name used in the file name.
    datapath : str
        Directory containing the file.

    Returns
    -------
    dict
        A store starting from ``initialize_results()``. Each non-underscore
        variable in the file replaces, or adds, the column of the same name
        as a flat list.
    """
    matname = platform + '_VALIDATION_STATS_' + mVarName + '.mat'
    data = sio.loadmat(os.path.join(datapath, matname))
    store = initialize_results()
    for key, value in data.items():
        # Skip loadmat's metadata entries (e.g. __header__, __version__).
        if key.startswith('_'):
            continue
        # Columns can come back 1-D (string arrays) or as 1xN rows, and an
        # empty column may come back as a 0x0 array, where indexing row 0
        # would raise IndexError. Flattening handles all of these.
        store[key] = np.ravel(value).tolist()
    return store
def _rscd_plot_name(platform_name, pVarName, year, month, suffix):
    """Build a plot file name: RSCD_<platform>_<param>_<year>_<MM><suffix>."""
    return ''.join(('RSCD_', platform_name, '_', pVarName, '_',
                    str(year), '_', str(month).zfill(2), suffix))


def _save_and_close(fig, results_dir, plotname):
    """Save *fig* as ``results_dir/plotname`` and close it."""
    fig.savefig(os.path.join(results_dir, plotname))
    plot_close(fig)


def _match_obs_exact(t_obs, v_obs, t_model):
    """
    Place observations onto the model time axis by exact timestamp match.

    Returns ``(matched, n_common)``. *matched* has one float entry per model
    time, NaN where no observation shares that timestamp. *n_common* is the
    number of model times that found a matching observation.
    """
    matched = np.full(len(t_model), np.nan)
    _, i_mod, i_obs = np.intersect1d(t_model, t_obs, return_indices=True)
    matched[i_mod] = v_obs[i_obs]
    return matched, len(i_mod)


def process_record(rec,
                   buoyFmt: str, platform: str,
                   mVarName: str, varOptions: list,
                   pVarName: str, results_dir: str,
                   angularData=False, plot_results=False,
                   min_samples: int = 200):
    """
    Process a matched model/buoy data record.

    Reads the model (RSCD) file and the buoy file named in *rec* and places
    the observed variable on the model time axis. It writes the paired
    series to a .mat file (see ``save_cleaned_ts_data``). When enough valid
    pairs exist, it computes validation metrics. Time-series and
    correlation plots are optional.

    Parameters
    ----------
    rec : sequence
        ``(model_dir, model_file, buoy_dir, buoy_file, year, month)``.
    buoyFmt : str
        Buoy data format, passed to ``extract_obs``.
    platform : str
        Platform name used in the saved time-series file name.
    mVarName : str
        Model variable name, passed to ``extract_model``.
    varOptions : list
        Candidate observation variable names, passed to ``extract_obs``.
    pVarName : str
        Parameter name used in plots and plot file names.
    results_dir : str
        Output directory for plots and saved series.
    angularData : bool, optional
        Use the circular metrics when True.
    plot_results : bool, optional
        Write time-series and correlation plots when True.
    min_samples : int, optional
        Minimum number of valid observations in the overlap window, and the
        threshold the paired-sample count must exceed, before metrics are
        computed (default 200).

    Returns
    -------
    dict or None
        A results record (see ``results_record``), or None when the record
        could not be validated.
    """
    rscdfname = rec[1]
    rscdfile = os.path.join(rec[0], rscdfname).replace('\\', '/')
    _, rscd_t, rscd_var, wmtw = extract_model(rscdfile, mVarName)
    print(rscdfname)
    buoyfname = rec[3]
    buoyfile = os.path.join(rec[2], buoyfname).replace('\\', '/')
    bVarName, buoy, buoy_t, buoy_var = extract_obs(buoyfile, buoyFmt,
                                                   varOptions)
    print('Buoy file = ', buoyfname)
    print('Buoy sampling period = ', buoy.sample_period)
    year = rec[4]
    month = rec[5]

    if bVarName == '':
        print('No variable match with ', varOptions)
        print('Buoy variables: ', buoy.file['vars'])
        print()
        return None

    print('Variable name = ', bVarName)
    overlaps, tindx = wmtw.timeOverlap(buoy_t)
    if tindx[0] is not None and tindx[1] is not None:
        # np.isnan is required here: `x != np.nan` is True for every
        # element, NaN included, so it cannot detect missing values.
        ngood = int(np.count_nonzero(
            ~np.isnan(buoy_var[tindx[0]:(tindx[1] + 1)])))
    else:
        ngood = 0
    if not (overlaps and ngood >= min_samples):
        print('Insuffient valid data for statistics. (# Good = ',
              ngood, ')')
        print()
        return None

    t0 = np.where(buoy_t >= rscd_t[0])[0][0]
    t1 = np.where(buoy_t <= rscd_t[-1])[0][-1]
    t_mod = rscd_t.copy()
    v_mod = rscd_var.copy()
    t_obs = buoy_t[t0:(t1 + 1)].copy()
    v_obs_raw = buoy_var[t0:(t1 + 1)].copy()
    print('Number of good data points = ', ngood)
    print('Points in RSCD time series = ', len(rscd_t))
    print('Points in Buoy time series = ', len(t_obs))
    print('Start times (model, in situ): ', rscd_t[0], buoy_t[t0])
    print('End times (model, in situ): ', rscd_t[-1], buoy_t[t1])

    model = []
    obs = []
    nsampl = 0
    # Median observation spacing in hours (time differences scaled by 24).
    sample_period = np.median(np.diff(t_obs)) * 24.0
    if sample_period != 1.0:
        max_dt = np.max(np.diff(t_obs))
        print('Maximum dt in obs =', max_dt * 24., 'hours')
        # Nearest-neighbour placement of the observations on model times.
        v_obs = time_interpolate_obs(t_obs, v_obs_raw, t_mod)
        plot_t, plot_v = t_mod, v_obs
    else:
        # Hourly observations: pair each model value with the observation
        # at the identical timestamp. Pairing by array position would
        # misalign the series whenever the buoy record skips a time step,
        # and would fail outright when the lengths differ.
        v_obs, n_common = _match_obs_exact(t_obs, v_obs_raw, t_mod)
        plot_t, plot_v = t_obs, v_obs_raw
        if n_common == 0:
            v_obs = None

    if v_obs is not None:
        if plot_results:
            fig = plot_time_series(t_mod, v_mod, plot_t, plot_v,
                                   bVarName, buoy.platform,
                                   pVarName, rscdfname,
                                   year, month,
                                   [rscd_t[0], rscd_t[-1]])
            _save_and_close(fig, results_dir,
                            _rscd_plot_name(buoy.platform, pVarName,
                                            year, month, '_TS.png'))
        # Keep only model times that have a valid (non-NaN) observation.
        indices = np.flatnonzero(~np.isnan(v_obs))
        model = v_mod[indices].copy()
        obs = v_obs[indices].copy()
        nsampl = len(indices)
        save_cleaned_ts_data(t_mod, v_mod, v_obs, platform, mVarName,
                             results_dir, year, month)

    if len(model) > min_samples and len(obs) > min_samples:
        if angularData:
            metrics = calculate_circular_metrics(model, obs, nsampl)
        else:
            metrics = calculate_metrics(model, obs, nsampl)
        valid_results = results_record(rscdfname, buoyfname, bVarName,
                                       year, month, metrics)
        if plot_results:
            fig = plot_correlation(obs, model, bVarName,
                                   buoy.platform,
                                   pVarName, rscdfname,
                                   year, month,
                                   metrics)
            _save_and_close(fig, results_dir,
                            _rscd_plot_name(buoy.platform, pVarName,
                                            year, month, '_CORR.png'))
        return valid_results

    print('No common indices found for metrics calculations.')
    print()
    return None
def validate_records(records, dataFmt: str, platform: str,
                     mVarName: str, varOptions: list,
                     pVarName,
                     results_dir,
                     plot_results=False):
    """
    Process a set of matched model/buoy data records.

    Each record is handed to ``process_record``, and every non-None result
    is appended to a store created by ``initialize_results``. A model
    variable named ``'dir'`` is treated as angular data, so it is scored
    with the circular metrics.

    Returns
    -------
    tuple
        ``(n_valid_results, results_store)``.
    """
    results_store = initialize_results()
    angularData = (mVarName == 'dir')
    n_valid_results = 0
    for record in records:
        outcome = process_record(record, dataFmt, platform,
                                 mVarName, varOptions,
                                 pVarName, results_dir,
                                 angularData, plot_results)
        if outcome is None:
            continue
        store_valid_results(results_store, outcome)
        n_valid_results += 1
    return n_valid_results, results_store