Source code for boa.metrics.metric_funcs

"""
########################
Metric Functions
########################

Functions used for Metrics

"""
from __future__ import annotations

import numpy as np
import scipy.stats as stats
import sklearn.metrics
from sklearn.metrics import __all__ as sklearn_all
from sklearn.metrics import mean_squared_error

from boa.logger import get_logger
from boa.utils import get_dictionary_from_callable

# Module-level logger obtained from boa's logging helper. It is not referenced
# by the functions visible here; presumably kept for use elsewhere — verify.
logger = get_logger()


def normalized_root_mean_squared_error(y_true, y_pred, normalizer="iqr", **kwargs):
    """Normalized root mean squared error

    Parameters
    ----------
    y_true : array_like
        With shape (n_samples,) or (n_samples, n_outputs)
        Ground truth (correct) target values.
    y_pred : array_like
        With shape (n_samples,) or (n_samples, n_outputs)
        Estimated target values.
    normalizer : str
        How to normalize the RMSE, options include iqr, std, mean, and range.
        (default iqr)
    **kwargs
        see sklearn.metrics.root_mean_squared_error (or
        sklearn.metrics.mean_squared_error on scikit-learn < 1.4) for
        additional options. They are forwarded after filtering with
        boa.utils.get_dictionary_from_callable.

    Returns
    -------
    nrmse : float or numpy.ndarray[float]
        A normalized version of RMSE

    Raises
    ------
    ValueError
        If ``normalizer`` is not 'iqr', 'std', 'mean', or 'range'.

    Notes
    -----
    The normalizing statistic is computed from ``y_pred``. There is no guard
    against a zero normalizer (e.g. constant predictions with 'iqr'/'range').
    """
    # NOTE(review): NRMSE is commonly normalized by statistics of the observed
    # values (y_true); this uses y_pred. Kept as-is — confirm it is intended.
    try:
        # scikit-learn >= 1.4 ships a dedicated RMSE function. The ``squared``
        # argument of mean_squared_error was deprecated in 1.4 and removed in 1.6.
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        root_mean_squared_error = None

    if root_mean_squared_error is not None:
        rmse = root_mean_squared_error(
            y_true,
            y_pred,
            **get_dictionary_from_callable(root_mean_squared_error, kwargs),
        )
    else:
        # Older scikit-learn only exposes RMSE through ``squared=False``. Drop
        # any caller-supplied ``squared`` so it cannot clash with ours. Copy
        # first so the helper's return value is never mutated.
        mse_kwargs = dict(get_dictionary_from_callable(mean_squared_error, kwargs))
        mse_kwargs.pop("squared", None)
        rmse = mean_squared_error(y_true, y_pred, squared=False, **mse_kwargs)

    if normalizer == "iqr":
        norm = stats.iqr(y_pred)
    elif normalizer == "std":
        norm = stats.tstd(y_pred)
    elif normalizer == "mean":
        norm = stats.tmean(y_pred)
    elif normalizer == "range":
        norm = np.ptp(y_pred)
    else:
        raise ValueError("normalizer must be 'iqr', 'std', 'mean', or 'range'.")

    return rmse / norm
def setup_sklearn_metric(metric_to_eval, instantiate=True, **kw):
    """Build a ``BOASklearnMetric`` for a sklearn metric, or a factory for one.

    Parameters
    ----------
    metric_to_eval
        Identifier of the sklearn metric to wrap (presumably a function name
        from ``sklearn.metrics`` — verify). It is always forwarded as
        ``metric_to_eval`` and is also the default ``name``.
    instantiate : bool
        If truthy (default), return a constructed metric. Otherwise return the
        factory function that constructs one.
    **kw
        Extra keyword arguments for ``BOASklearnMetric``. They may override
        ``name`` but never ``metric_to_eval``.

    Returns
    -------
    BOASklearnMetric or callable
        The metric instance, or a factory. The factory accepts further keyword
        arguments, which take precedence over ``kw``, and returns a new
        ``BOASklearnMetric``.
    """
    # Imported at call time rather than at module import — presumably to avoid
    # a circular import with boa.metrics.metrics (verify).
    import boa.metrics.metrics

    def modular_sklearn_metric(**kwargs):
        # Precedence from lowest to highest: default name, setup-time ``kw``,
        # call-time ``kwargs``, then the fixed ``metric_to_eval``.
        options = {"name": metric_to_eval}
        options.update(kw)
        options.update(kwargs)
        options["metric_to_eval"] = metric_to_eval
        return boa.metrics.metrics.BOASklearnMetric(**options)

    if not instantiate:
        return modular_sklearn_metric
    return modular_sklearn_metric(**kw)
def get_sklearn_func(metric_to_eval):
    """Look up an object exported by ``sklearn.metrics`` by name.

    Parameters
    ----------
    metric_to_eval : str or object
        A name listed in ``sklearn.metrics.__all__``, or an object whose
        ``name`` attribute is such a name.

    Returns
    -------
    The attribute of ``sklearn.metrics`` with that name.

    Raises
    ------
    AttributeError
        If neither ``metric_to_eval`` nor its ``name`` attribute is listed in
        ``sklearn.metrics.__all__``.
    """
    # Direct name lookup, e.g. "mean_absolute_error".
    if metric_to_eval in sklearn_all:
        return getattr(sklearn.metrics, metric_to_eval)

    # A cloned ModularMetric can arrive as a class instead of a string. Its
    # ``name`` attribute then carries the sklearn function name.
    if getattr(metric_to_eval, "name", None) in sklearn_all:
        return getattr(sklearn.metrics, metric_to_eval.name)

    raise AttributeError(f"Sklearn metric: {metric_to_eval} not found!")