Source code for ocmw.calval.ww3WavesValidation

# -*- coding: utf-8 -*-
"""
This module contains the functions for calculating parameter based validation statistics for 
a set of matchup database records. 

"""
# ----------------------------------------------------------------------------
#   IMPORTS
# ----------------------------------------------------------------------------
# Standard Python Dependencies
import os
import copy
# Non-Standard Python Dependencies
import numpy as np
import scipy.io as sio
from scipy.interpolate import interp1d
# Local Module Dependencies
from ocmw.core.graphics import plot_time_series, plot_correlation, plot_close
from ocmw.core.timeFuncs import temporalCoverage
from ocmw.core.statsTools import corr_coef, mean_bias, norm_mean_bias
from ocmw.core.statsTools import mean_absolute_error, norm_ma_error
from ocmw.core.statsTools import rms_error, norm_rmse, scatter_index
from ocmw.core.statsTools import circ_corr, circ_mean_bias, circ_norm_mean_bias
from ocmw.core.statsTools import circ_mean_abs_error, circ_norm_ma_error
from ocmw.core.statsTools import circ_rms_error, circ_norm_rmse
from ocmw.core.statsTools import circ_scatter_index
from ocmw.dataman.dataReaders import wavebuoy, emecwb, ww3
# Other Dependencies

# ----------------------------------------------------------------------------
#   GLOBAL VARIABLES
# ----------------------------------------------------------------------------


# ----------------------------------------------------------------------------
#   CLASS DEFINITIONS
# ----------------------------------------------------------------------------


# ----------------------------------------------------------------------------
#   FUNCTION DEFINITIONS
# ----------------------------------------------------------------------------
# ======================== Data Selection =================================
def getVarName(record, Options: list):
    """
    Return the first name in ``Options`` that ``record`` reports as present.

    Every candidate is queried with ``record.hasVar`` (not just until the
    first hit) and the earliest matching entry of ``Options`` is returned.

    Parameters
    ----------
    record : object
        Any object exposing a ``hasVar(name)`` method.
    Options : list
        Candidate variable names, in order of preference.

    Returns
    -------
    str
        The first available name, or ``''`` when none of the options exist.
    """
    availability = [bool(record.hasVar(candidate)) for candidate in Options]
    for candidate, available in zip(Options, availability):
        if available:
            return candidate
    return ''
def getSurfaceIndx(record, dataFmt):
    """
    Get index of surface layer in depth data

    Parameters
    ----------
    record : object
        Buoy data object exposing ``getVar(name)`` returning an array.
    dataFmt : str
        ``'InSituTAC'`` (depth variable ``DEPH``) or ``'SeaDataNet'``
        (depth variable ``DEPTH``).

    Returns
    -------
    int
        ``0`` when the depth data is 0-D/1-D; otherwise the first column of
        the first row of the depth array whose value is exactly ``0.0``.

    Raises
    ------
    ValueError
        If ``dataFmt`` is not one of the supported formats (previously this
        surfaced as an opaque ``UnboundLocalError``).
    IndexError
        If the depth array is 2-D and its first row contains no ``0.0``.
    """
    depth_var_names = {'InSituTAC': 'DEPH', 'SeaDataNet': 'DEPTH'}
    if dataFmt not in depth_var_names:
        raise ValueError('Unknown file format: ' + str(dataFmt))
    depth_name = depth_var_names[dataFmt]
    depth = record.getVar(depth_name)
    if len(depth.shape) <= 1:
        # A single depth level (or scalar): the surface is the only entry.
        return 0
    # NOTE(review): the first row is assumed to be representative of all
    # time steps (i.e. depth levels do not change with time) - confirm.
    surfIndx = np.where(depth[0, :] == 0.0)[0]
    if len(surfIndx) == 0:
        raise IndexError('No surface (0.0) depth level found in ' + depth_name)
    return surfIndx[0]
# =========================== Data Extraction =============================
def extract_model(rscdfile, mVarName: str):
    """
    Extract model output for a specified variable

    Opens the WW3 file with ``ww3``, builds the temporal coverage window
    from its start/end times and reads the time axis and the requested
    variable. For the frequency variables ``'fp'`` and ``'f02'`` the
    reciprocal (``1.0 / value``) is returned instead of the raw values.

    Returns
    -------
    tuple
        ``(reader, times, values, coverage_window)``.
    """
    model_reader = ww3(rscdfile)
    coverage = temporalCoverage(model_reader.time_start, model_reader.time_end)
    model_times = model_reader.getVar('time')
    raw_values = model_reader.getVar(mVarName)
    if mVarName in ('fp', 'f02'):
        model_values = 1.0/raw_values
    else:
        model_values = raw_values
    return model_reader, model_times, model_values, coverage
def extract_obs(buoyfile, buoyDataFmt: str, varOptions: list):
    """
    Extract *in situ* buoy data for a specified data format and variable

    Parameters
    ----------
    buoyfile : str
        Path of the buoy data file.
    buoyDataFmt : str
        ``'InSituTAC'`` (read with ``wavebuoy``) or ``'SeaDataNet'``
        (read with ``emecwb``). Any other value is reported on stdout.
    varOptions : list
        Candidate variable names, in order of preference.

    Returns
    -------
    tuple
        ``(bVarName, buoy, buoy_t, buoy_var)``. When no variable matches,
        ``bVarName`` is ``''`` and both time series are empty lists. For an
        unknown format ``buoy`` is ``None`` (previously this path raised
        ``UnboundLocalError`` because ``buoy`` was never assigned).
    """
    buoy = None
    bVarName = ''
    if buoyDataFmt == 'InSituTAC':
        buoy = wavebuoy(buoyfile)
    elif buoyDataFmt == 'SeaDataNet':
        buoy = emecwb(buoyfile)
    else:
        print('Unknown file format: '+buoyDataFmt)
    if buoy is not None:
        bVarName = getVarName(buoy, varOptions)
    if bVarName == '':
        return bVarName, buoy, [], []
    # Only the surface layer is used for the validation time series.
    varIndx = getSurfaceIndx(buoy, buoyDataFmt)
    buoy_t, buoy_var = buoy.getQCTimeseries(bVarName, varIndx)
    return bVarName, buoy, buoy_t, buoy_var
# ========================== Data Interpolators ===========================
def time_interpolate_obs(t_obs, t_var, t_model, max_gap_hours=1.5,
                         max_offset_hours=1.0):
    """
    Interpolation of observation data onto the model output timestamps

    For every model timestamp the two observations closest in time are
    located. The value of the nearer one is taken (ties resolve to the
    earlier observation, matching ``scipy.interpolate.interp1d`` with
    ``kind='nearest'``) only if the two are less than ``max_gap_hours``
    apart and the closer one is within ``max_offset_hours`` of the model
    time. Otherwise the result is NaN.

    Parameters
    ----------
    t_obs : array_like
        Observation times in days (differences are converted to hours by
        multiplying by 24). Assumed to be in ascending order.
    t_var : array_like
        Observation values, indexable like ``t_obs``.
    t_model : array_like
        Model output times, in the same units as ``t_obs``.
    max_gap_hours : float, optional
        Maximum spacing between the two bracketing observations (default 1.5).
    max_offset_hours : float, optional
        Maximum distance from the model time to the nearest observation
        (default 1.0).

    Returns
    -------
    numpy.ndarray
        Float array shaped like ``t_model``; NaN where no acceptable
        observation exists. With fewer than two observations every entry is
        NaN (previously this raised or relied on empty-array truthiness).
    """
    t_obs = np.asarray(t_obs, dtype=float)
    intrp_var = np.full(np.shape(t_model), np.nan)
    if t_obs.size < 2:
        # Two neighbours are required to judge the local sampling gap.
        return intrp_var
    for i_ref, t_ref in enumerate(t_model):
        # Indices of the two observations nearest in time, in index order.
        idx = np.sort(np.argsort(np.abs(t_obs - t_ref))[0:2])
        t_lo = t_obs[idx[0]]
        t_hi = t_obs[idx[1]]
        gap_hours = abs(t_hi - t_lo)*24.0
        offset_hours = min(abs(t_lo - t_ref), abs(t_hi - t_ref))*24.0
        if (gap_hours < max_gap_hours) and (offset_hours < max_offset_hours):
            # Nearest-neighbour choice: at or before the midpoint takes the
            # earlier sample (same rule as interp1d(kind='nearest')); this also
            # covers model times lying outside the [t_lo, t_hi] interval.
            midpoint = (t_lo + t_hi)/2.0
            pick = idx[0] if t_ref <= midpoint else idx[1]
            intrp_var[i_ref] = t_var[pick]
    return intrp_var
# ============================ Analysis Tools =============================
def initialize_results():
    """
    Initialise a dictionary for collecting validation results

    Returns
    -------
    dict
        One key per results column (in display order), each mapped to its
        own new, empty list.
    """
    columns = ('MODEL', 'OBSERVATION', 'YEAR', 'MONTH', 'PARAM',
               'R', 'MB', 'NMB', 'MAE', 'NMAE', 'RMSE', 'NRMSE', 'SI',
               'NSMPL')
    return {column: [] for column in columns}
def calculate_circular_metrics(model, obs, nsampl, degrees=True):
    """
    Calculate circular validation metrics for a pair of modelled and
    observation time series

    Parameters
    ----------
    model, obs : array_like
        Matched modelled and observed angular values.
    nsampl : int
        Sample count, stored unchanged under ``'NSMPL'``.
    degrees : bool, optional
        Passed through to every circular statistic (default True).

    Returns
    -------
    dict
        Keys ``R, MB, NMB, MAE, NMAE, RMSE, NRMSE, SI, NSMPL``.
    """
    stat_args = (model, obs, degrees)
    # Dict literals evaluate left to right, so the statistics are computed
    # in the same order as they are listed.
    return {
        'R': circ_corr(*stat_args),
        'MB': circ_mean_bias(*stat_args),
        'NMB': circ_norm_mean_bias(*stat_args),
        'MAE': circ_mean_abs_error(*stat_args),
        'NMAE': circ_norm_ma_error(*stat_args),
        'RMSE': circ_rms_error(*stat_args),
        'NRMSE': circ_norm_rmse(*stat_args),
        'SI': circ_scatter_index(*stat_args),
        'NSMPL': nsampl,
    }
def calculate_metrics(model, obs, nsampl):
    """
    Calculate standard validation metrics for a pair of modelled and
    observation time series

    Parameters
    ----------
    model, obs : array_like
        Matched modelled and observed values.
    nsampl : int
        Sample count, stored unchanged under ``'NSMPL'``.

    Returns
    -------
    dict
        Keys ``R, MB, NMB, MAE, NMAE, RMSE, NRMSE, SI, NSMPL``.
    """
    pair = (model, obs)
    # Evaluated left to right, i.e. in the listed order.
    return {
        'R': corr_coef(*pair),
        'MB': mean_bias(*pair),
        'NMB': norm_mean_bias(*pair),
        'MAE': mean_absolute_error(*pair),
        'NMAE': norm_ma_error(*pair),
        'RMSE': rms_error(*pair),
        'NRMSE': norm_rmse(*pair),
        'SI': scatter_index(*pair),
        'NSMPL': nsampl,
    }
def results_record(modelFileName, obsFileName, obsVarName, year, month, metrics):
    """
    Create a results record from calculated validation metrics

    The observation file name is stored under the key ``'BUOY'`` and the
    observed variable name under ``'PARAM'``. Every entry of ``metrics`` is
    then copied in, overriding an identifying field of the same name.

    Returns
    -------
    dict
        The combined record.
    """
    record = {
        'MODEL': modelFileName,
        'BUOY': obsFileName,
        'YEAR': year,
        'MONTH': month,
        'PARAM': obsVarName,
    }
    record.update(metrics)
    return record
def store_valid_results(store, results):
    """
    Store a validation results record in the results dictionary

    Appends one value per column of ``store``, taken from ``results``. The
    record key ``'BUOY'`` feeds the store column ``'OBSERVATION'``; every
    other column uses the same key in both. Returns ``None``.
    """
    # (store column, results-record key), in append order.
    column_sources = (
        ('MODEL', 'MODEL'),
        ('OBSERVATION', 'BUOY'),
        ('YEAR', 'YEAR'),
        ('MONTH', 'MONTH'),
        ('PARAM', 'PARAM'),
        ('R', 'R'),
        ('MB', 'MB'),
        ('NMB', 'NMB'),
        ('MAE', 'MAE'),
        ('NMAE', 'NMAE'),
        ('RMSE', 'RMSE'),
        ('NRMSE', 'NRMSE'),
        ('SI', 'SI'),
        ('NSMPL', 'NSMPL'),
    )
    for column, source_key in column_sources:
        store[column].append(results[source_key])
def display_tabulated_results(store, mVarName):
    """
    Display validation results as a table on the standard output

    The first row lists the column names of ``store``; each following row
    holds one record. The header is formatted with the column widths only.
    The numeric precision specifiers (``.3``/``.4``) used for data rows
    would truncate string headers, e.g. ``NRMSE`` printed as ``NRMS``.

    Parameters
    ----------
    store : dict
        Column name -> list of values, e.g. as built by the results helpers
        in this module (14 columns are expected by the format string).
    mVarName : str
        Variable name shown in the title line.
    """
    print('Validation Statistics for', mVarName, '\n')
    if not store:
        return
    lead = "{:<60}{:<25}{:^7}{:^7}{:^7}"
    header_fmt = lead + "{:^9}"*9
    row_fmt = lead + "{:^9.3}" + "{:^9.4}"*7 + "{:^9}"
    print(header_fmt.format(*store.keys()))
    for row in zip(*store.values()):
        print(row_fmt.format(*row))
def save_tabulated_results(store, platform, mVarName, outpath):
    """
    Save tabulated validation results as a text file and a MATLAB file

    Writes ``<platform>_VALIDATION_STATS_<mVarName>.txt`` (a title line,
    a header row and one row per record) and the same ``store`` dictionary
    as ``<platform>_VALIDATION_STATS_<mVarName>.mat`` into ``outpath``.
    The header row uses widths only, so string headers are not truncated by
    the numeric precision specifiers (previously ``NRMSE`` became ``NRMS``).

    Returns
    -------
    object
        Whatever ``scipy.io.savemat`` returns (``None`` in current SciPy).
    """
    outtxtname = platform+'_VALIDATION_STATS_'+mVarName+'.txt'
    outtxtfile = os.path.join(outpath, outtxtname)
    lead = "{:<60}{:<25}{:^7}{:^7}{:^7}"
    header_fmt = lead + "{:^9}"*9
    row_fmt = lead + "{:^9.3}" + "{:^9.4}"*7 + "{:^9}"
    with open(outtxtfile, 'w', encoding='utf-8') as ofile:
        ofile.write('Validation Statistics for '+mVarName+'\n')
        if store:
            ofile.write(header_fmt.format(*store.keys())+'\n')
            for row in zip(*store.values()):
                ofile.write(row_fmt.format(*row)+'\n')
    outmatname = platform+'_VALIDATION_STATS_'+mVarName+'.mat'
    outmatfile = os.path.join(outpath, outmatname)
    saveFlg = sio.savemat(outmatfile, store)
    return saveFlg
def save_cleaned_ts_data(t, rscd, buoy, platform, mVarName, outpath, year, month):
    """
    Save cleaned matched model and buoy timeseries as a matlab \\*.mat binary file

    The file is named ``<platform>_VALIDATION_TS_<mVarName>_<year>_<MM>.mat``
    (month zero-padded to two digits) and stored in ``outpath``. It holds
    the variables ``platform, var, year, month, time, rscd, buoy``.

    Returns
    -------
    object
        Whatever ``scipy.io.savemat`` returns.
    """
    stem = '_'.join((platform, 'VALIDATION', 'TS', mVarName,
                     str(year), str(month).zfill(2)))
    payload = dict(
        platform=platform,
        var=mVarName,
        year=year,
        month=month,
        time=t,
        rscd=rscd,
        buoy=buoy,
    )
    return sio.savemat(os.path.join(outpath, stem + '.mat'), payload)
def load_binary_results(platform, mVarName, datapath):
    """
    Load binary validation results from a MATLAB file

    Reads ``<platform>_VALIDATION_STATS_<mVarName>.mat`` from ``datapath``.
    Each stored variable whose name does not start with ``'_'`` replaces the
    matching column of a freshly initialised results dictionary. 1-D arrays
    are converted as they are; higher-rank arrays contribute their first
    row.

    Returns
    -------
    dict
        Column name -> list of values.
    """
    matfile = os.path.join(datapath, platform+'_VALIDATION_STATS_'+mVarName+'.mat')
    contents = sio.loadmat(matfile)
    store = initialize_results()
    for name, array in contents.items():
        if name[0] == '_':
            # loadmat metadata entries (e.g. __header__) start with '_'.
            continue
        column = array if len(array.shape) == 1 else array[0]
        store[name] = column.tolist()
    return store
def _plot_filename(platform_name, pVarName, year, month, suffix):
    """Build a ``RSCD_<platform>_<param>_<year>_<MM>_<suffix>.png`` plot name."""
    return ''.join(('RSCD_', platform_name, '_', pVarName, '_', str(year), '_',
                    str(month).zfill(2), '_', suffix, '.png'))


def _save_and_close(fig, results_dir, plotname):
    """Save ``fig`` as ``results_dir/plotname`` and close it."""
    fig.savefig(os.path.join(results_dir, plotname))
    plot_close(fig)


def process_record(rec, buoyFmt: str, platform: str, mVarName: str,
                   varOptions: list, pVarName: str, results_dir: str,
                   angularData=False, plot_results=False, minNSampl=200):
    """
    Process a matched model/buoy data record

    Parameters
    ----------
    rec : sequence
        ``(model_dir, model_file, buoy_dir, buoy_file, year, month)``.
    buoyFmt : str
        Buoy file format, passed to ``extract_obs``.
    platform : str
        Platform name used in saved time-series file names.
    mVarName : str
        Model variable name.
    varOptions : list
        Candidate buoy variable names, in order of preference.
    pVarName : str
        Variable label used in plots and plot file names.
    results_dir : str
        Output directory for plots and cleaned time series.
    angularData : bool, optional
        Use circular statistics (e.g. for directions).
    plot_results : bool, optional
        Save time-series and correlation plots.
    minNSampl : int, optional
        Minimum number of valid observations required in the overlap window
        (``>=``) and of matched samples (``>``). Default 200.

    Returns
    -------
    dict or None
        A results record (see ``results_record``), or ``None`` when no
        statistics could be computed.
    """
    rscdpath = copy.copy(rec[0])
    rscdfname = copy.copy(rec[1])
    rscdfile = os.path.join(rscdpath, rscdfname).replace('\\', '/')
    _, rscd_t, rscd_var, wmtw = extract_model(rscdfile, mVarName)
    print(rscdfname)
    buoypath = copy.copy(rec[2])
    buoyfname = copy.copy(rec[3])
    buoyfile = os.path.join(buoypath, buoyfname).replace('\\', '/')
    bVarName, buoy, buoy_t, buoy_var = extract_obs(buoyfile, buoyFmt, varOptions)
    if buoy is None:
        # No reader for this buoy format; extract_obs has already reported it.
        return None
    print('Buoy file = ', buoyfname)
    print('Buoy sampling period = ', buoy.sample_period)
    year = copy.copy(rec[4])
    month = copy.copy(rec[5])

    if bVarName == '':
        print('No variable match with ', varOptions)
        print('Buoy variables: ', buoy.file['vars'])
        print()
        return None

    print('Variable name = ', bVarName)
    overlaps, tindx = wmtw.timeOverlap(buoy_t)
    if (tindx[0] is not None) and (tindx[1] is not None):
        # Count non-NaN observations in the overlap window. The previous
        # test `x != np.nan` is always True (NaN never compares equal), so
        # NaNs were counted as good data.
        window = buoy_var[tindx[0]:(tindx[1]+1)]
        ngood = int(np.count_nonzero(~np.isnan(window)))
    else:
        ngood = 0
    if not (overlaps and (ngood >= minNSampl)):
        print('Insuffient valid data for statistics. (# Good = ', ngood, ')')
        print()
        return None

    # First/last observation inside the model time span.
    t0 = np.where(buoy_t >= rscd_t[0])[0][0]
    t1 = np.where(buoy_t <= rscd_t[-1])[0][-1]
    obs_span = slice(t0, t1 + 1)
    print('Number of good data points = ', ngood)
    print('Points in RSCD time series = ', len(rscd_t))
    print('Points in Buoy time series = ', len(buoy_t[obs_span]))
    print('Start times (model, in situ): ', rscd_t[0], buoy_t[t0])
    print('End times (model, in situ): ', rscd_t[-1], buoy_t[t1])
    model = []
    obs = []
    nsampl = 0
    # Median observation spacing in hours (times are in days).
    sample_period = np.median(np.diff(buoy_t[obs_span]))*24.0
    if sample_period != 1.0:
        max_dt = np.max(np.diff(buoy_t[obs_span]))
        print('Maximum dt in obs =', max_dt*24., 'hours')
        # Interpolate obs onto model times
        t_mod = rscd_t.copy()
        v_mod = rscd_var.copy()
        t_obs = buoy_t[obs_span].copy()
        v_obs = buoy_var[obs_span].copy()
        v_obs = time_interpolate_obs(t_obs, v_obs, t_mod)
        if plot_results:
            fig = plot_time_series(t_mod, v_mod, t_mod, v_obs, bVarName,
                                   buoy.platform, pVarName, rscdfname, year,
                                   month, [rscd_t[0], rscd_t[-1]])
            _save_and_close(fig, results_dir,
                            _plot_filename(buoy.platform, pVarName, year, month, 'TS'))
        # Keep only model times with a usable (non-NaN) observation.
        indices = np.flatnonzero(~np.isnan(v_obs))
        model = v_mod[indices].copy()
        obs = v_obs[indices].copy()
        nsampl = len(indices)
        save_cleaned_ts_data(t_mod, v_mod, v_obs, platform, mVarName,
                             results_dir, year, month)
    elif buoy.sample_period == 1.0:
        both = set(rscd_t).intersection(buoy_t)
        if len(both) > 0:
            t_mod = rscd_t.copy()
            v_mod = rscd_var.copy()
            t_obs = buoy_t[obs_span].copy()
            v_obs = buoy_var[obs_span].copy()
            if plot_results:
                fig = plot_time_series(t_mod, v_mod, t_obs, v_obs, bVarName,
                                       buoy.platform, pVarName, rscdfname,
                                       year, month, [rscd_t[0], rscd_t[-1]])
                _save_and_close(fig, results_dir,
                                _plot_filename(buoy.platform, pVarName, year, month, 'TS'))
            # NOTE(review): observation indices are reused to index the
            # model series, which assumes both hourly series are aligned
            # one-to-one from t0 with no gaps - confirm.
            indices = np.flatnonzero(~np.isnan(v_obs))
            model = v_mod[indices].copy()
            obs = v_obs[indices].copy()
            # Report the number of samples actually used for the metrics
            # (previously len(both), which also counted NaN observations).
            nsampl = len(indices)
            save_cleaned_ts_data(t_mod, v_mod, v_obs, platform, mVarName,
                                 results_dir, year, month)

    if (len(model) > minNSampl) and (len(obs) > minNSampl):
        if angularData:
            metrics = calculate_circular_metrics(model, obs, nsampl)
        else:
            metrics = calculate_metrics(model, obs, nsampl)
        valid_results = results_record(rscdfname, buoyfname, bVarName,
                                       year, month, metrics)
        if plot_results:
            fig = plot_correlation(obs, model, bVarName, buoy.platform,
                                   pVarName, rscdfname, year, month, metrics)
            _save_and_close(fig, results_dir,
                            _plot_filename(buoy.platform, pVarName, year, month, 'CORR'))
        return valid_results

    print('No common indices found for metrics calculations.')
    print()
    return None
def validate_records(records, dataFmt: str, platform: str, mVarName: str,
                     varOptions: list, pVarName, results_dir,
                     plot_results=False):
    """
    Process a set of matched model/buoy data records

    Each record is passed to ``process_record``. Circular statistics are
    selected when ``mVarName == 'dir'``. Every non-``None`` result is
    appended to a results dictionary created with ``initialize_results``.

    Returns
    -------
    tuple
        ``(number_of_valid_results, results_store)``.
    """
    results_store = initialize_results()
    angularData = (mVarName == 'dir')
    n_valid_results = 0
    for record in records:
        outcome = process_record(record, dataFmt, platform, mVarName,
                                 varOptions, pVarName, results_dir,
                                 angularData, plot_results)
        if outcome is None:
            continue
        store_valid_results(results_store, outcome)
        n_valid_results += 1
    return n_valid_results, results_store