"""
Error Functions Module (erf)
============================
This module provides a collection of common error/loss functions for ODE parameter
optimization. These functions are designed to work with the invode optimization
framework and provide standardized metrics for model fitting.
All error functions follow the signature: error_func(y_pred) -> float
where y_pred is the model prediction and the function returns a scalar error value.
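For example, using the MSE class defined below:
>>> import numpy as np
>>> data = np.array([1.0, 2.0, 3.0])
>>> error_func = MSE(data)
>>> prediction = np.array([1.1, 2.0, 2.9])
>>> error = error_func(prediction)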
"""
import numpy as np
from typing import Union, Optional, Callable, Tuple
import warnings
class ErrorFunction:
"""
Base class for error functions with data storage and validation.
This class handles common functionality like data storage, validation,
and provides a consistent interface for all error functions.
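Examples
--------
A minimal custom error function (illustrative) only needs to implement
`__call__`:
>>> class MaxAbsError(ErrorFunction):
...     def __call__(self, y_pred):
...         y_pred = self._validate_prediction(y_pred)
...         if np.isscalar(y_pred):
...             return y_pred  # non-finite prediction sentinel
...         return float(np.max(np.abs(y_pred - self.data)))
>>> err = MaxAbsError(np.array([1.0, 2.0, 3.0]))
>>> value = err(np.array([1.1, 2.0, 2.9]))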
"""
def __init__(self, data: np.ndarray, **kwargs):
"""
Initialize error function with reference data.
Parameters
----------
data : np.ndarray
Reference/observed data to compare predictions against.
"""
self.data = np.asarray(data)
if self.data.size == 0:
raise ValueError("Data array cannot be empty")
def __call__(self, y_pred: np.ndarray) -> float:
"""
Compute error between prediction and reference data.
Parameters
----------
y_pred : np.ndarray
Model predictions to compare against reference data.
Returns
-------
float
Computed error value.
"""
raise NotImplementedError("Subclasses must implement __call__ method")
def _validate_prediction(self, y_pred: np.ndarray) -> Union[np.ndarray, float]:
"""Validate and format the prediction array; returns np.inf if it contains non-finite values."""
y_pred = np.asarray(y_pred)
if y_pred.shape != self.data.shape:
raise ValueError(f"Prediction shape {y_pred.shape} does not match "
f"data shape {self.data.shape}")
if not np.isfinite(y_pred).all():
warnings.warn("Prediction contains non-finite values")
return np.inf
return y_pred
class MSE(ErrorFunction):
"""
Mean Squared Error (MSE) function.
Computes: MSE = (1/n) * Σ(y_pred - y_data)²
This is the most common error function for continuous regression problems.
It penalizes large errors more heavily than small ones due to the squaring.
Examples
--------
>>> import numpy as np
>>> data = np.array([1.0, 2.0, 3.0, 4.0])
>>> mse_func = MSE(data)
>>> prediction = np.array([1.1, 2.2, 2.8, 4.1])
>>> error = mse_func(prediction)
>>> print(f"MSE: {error:.4f}")
"""
def __call__(self, y_pred: np.ndarray) -> float:
"""Compute Mean Squared Error."""
y_pred = self._validate_prediction(y_pred)
if not np.isscalar(y_pred):
residuals = y_pred - self.data
return float(np.mean(residuals**2))
return y_pred
class ChiSquaredMSE(ErrorFunction):
"""
Chi-squared weighted Mean Squared Error.
Computes: χ² = Σ((y_pred - y_data)² / σ²)
This error function weights residuals by their expected variance (σ²),
making it appropriate when different data points have different uncertainties.
Parameters
----------
data : np.ndarray
Reference/observed data.
sigma : np.ndarray or float
Standard deviation/uncertainty for each data point. If float,
assumes constant uncertainty across all points.
normalize : bool, optional
If True, normalize by number of data points (default False).
Examples
--------
>>> data = np.array([1.0, 2.0, 3.0, 4.0])
>>> sigma = np.array([0.1, 0.2, 0.1, 0.3]) # Different uncertainties
>>> chi2_func = ChiSquaredMSE(data, sigma=sigma)
>>> prediction = np.array([1.1, 2.2, 2.8, 4.1])
>>> error = chi2_func(prediction)
>>> print(f"Chi-squared: {error:.4f}")
>>> # Constant uncertainty
>>> chi2_func_const = ChiSquaredMSE(data, sigma=0.2)
"""
def __init__(self, data: np.ndarray, sigma: Union[np.ndarray, float],
normalize: bool = False):
super().__init__(data)
if np.isscalar(sigma):
# dtype=float guards against integer data arrays truncating sigma to zero
self.sigma = np.full_like(self.data, float(sigma), dtype=float)
else:
self.sigma = np.asarray(sigma)
if self.sigma.shape != self.data.shape:
raise ValueError(f"Sigma shape {self.sigma.shape} does not match "
f"data shape {self.data.shape}")
if (self.sigma <= 0).any():
raise ValueError("All sigma values must be positive")
self.normalize = normalize
def __call__(self, y_pred: np.ndarray) -> float:
"""Compute Chi-squared weighted error."""
y_pred = self._validate_prediction(y_pred)
if not np.isscalar(y_pred):
residuals = y_pred - self.data
chi_squared = np.sum((residuals / self.sigma)**2)
if self.normalize:
chi_squared /= len(self.data)
return float(chi_squared)
return y_pred
class LogLikelihood(ErrorFunction):
"""
Gaussian Log-Likelihood Error Function.
Computes the log-likelihood of the predicted values `y_pred` under the
assumption that the observed data `data` follows a Gaussian distribution
with mean equal to `y_pred` and constant variance `sigma^2`.
The log-likelihood is given by:
LL(μ, σ²) = -n/2 * log(2πσ²) - 1/(2σ²) * Σ(yi - μ)²
Parameters
----------
data : np.ndarray
Observed data points.
sigma : float
Standard deviation of the Gaussian noise. Must be positive.
Raises
------
ValueError
If sigma is not positive or if the data array is empty.
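Notes
-----
Unlike the other functions in this module, the returned value is a
log-likelihood, so larger values indicate a better fit; negate it if the
optimizer minimizes its objective.
Examples
--------
>>> data = np.array([1.0, 2.0, 3.0, 4.0])
>>> ll_func = LogLikelihood(data, sigma=0.1)
>>> prediction = np.array([1.1, 2.2, 2.8, 4.1])
>>> ll = ll_func(prediction)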
"""
def __init__(self, data: np.ndarray, sigma: float, **kwargs):
super().__init__(data, **kwargs)
if sigma <= 0:
raise ValueError("Standard deviation sigma must be positive")
self.sigma = sigma
self.n = self.data.size
def __call__(self, y_pred: np.ndarray) -> float:
"""Compute the Gaussian log-likelihood of the data given the prediction."""
y_pred = self._validate_prediction(y_pred)
if np.isscalar(y_pred):
return -np.inf  # prediction contained non-finite values
residuals = self.data - y_pred
squared_error = np.sum(residuals**2)
ll = -0.5 * self.n * np.log(2 * np.pi * self.sigma**2) - (0.5 / self.sigma**2) * squared_error
return float(ll)
class MAE(ErrorFunction):
"""
Mean Absolute Error (MAE) function.
Computes: MAE = (1/n) * Σ|y_pred - y_data|
MAE is more robust to outliers than MSE since it doesn't square the residuals.
It provides a linear penalty for errors.
Examples
--------
>>> data = np.array([1.0, 2.0, 3.0, 4.0])
>>> mae_func = MAE(data)
>>> prediction = np.array([1.1, 2.2, 2.8, 4.1])
>>> error = mae_func(prediction)
>>> print(f"MAE: {error:.4f}")
"""
def __call__(self, y_pred: np.ndarray) -> float:
"""Compute Mean Absolute Error."""
y_pred = self._validate_prediction(y_pred)
if not np.isscalar(y_pred):
residuals = np.abs(y_pred - self.data)
return float(np.mean(residuals))
return y_pred
class RMSE(ErrorFunction):
"""
Root Mean Squared Error (RMSE) function.
Computes: RMSE = √((1/n) * Σ(y_pred - y_data)²)
RMSE is in the same units as the original data, making it more interpretable
than MSE while maintaining the same optimization properties.
Examples
--------
>>> data = np.array([1.0, 2.0, 3.0, 4.0])
>>> rmse_func = RMSE(data)
>>> prediction = np.array([1.1, 2.2, 2.8, 4.1])
>>> error = rmse_func(prediction)
>>> print(f"RMSE: {error:.4f}")
"""
def __call__(self, y_pred: np.ndarray) -> float:
"""Compute Root Mean Squared Error."""
y_pred = self._validate_prediction(y_pred)
if not np.isscalar(y_pred):
residuals = y_pred - self.data
mse = np.mean(residuals**2)
return float(np.sqrt(mse))
return y_pred
class HuberLoss(ErrorFunction):
"""
Huber Loss function (robust regression).
Combines the best properties of MSE and MAE:
- Quadratic for small errors (|error| <= delta)
- Linear for large errors (|error| > delta)
This makes it less sensitive to outliers than MSE while maintaining
smoothness for optimization.
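Per data point, with residual r = y_pred - y_data:
L(r) = 0.5 * r²                       if |r| <= delta
L(r) = delta * (|r| - 0.5 * delta)    if |r| > delta
The total loss is the mean of L(r) over all data points.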
Parameters
----------
data : np.ndarray
Reference/observed data.
delta : float, optional
Threshold for switching between quadratic and linear loss. Default is 1.0.
Examples
--------
>>> data = np.array([1.0, 2.0, 3.0, 4.0])
>>> huber_func = HuberLoss(data, delta=0.5)
>>> prediction = np.array([1.1, 2.2, 2.8, 4.1])
>>> error = huber_func(prediction)
>>> print(f"Huber Loss: {error:.4f}")
"""
def __init__(self, data: np.ndarray, delta: float = 1.0):
super().__init__(data)
if delta <= 0:
raise ValueError("Delta must be positive")
self.delta = delta
def __call__(self, y_pred: np.ndarray) -> float:
"""Compute Huber Loss."""
y_pred = self._validate_prediction(y_pred)
if not np.isscalar(y_pred):
residuals = np.abs(y_pred - self.data)
# Quadratic for small errors, linear for large errors
quadratic_mask = residuals <= self.delta
quadratic_loss = 0.5 * residuals[quadratic_mask]**2
linear_loss = self.delta * (residuals[~quadratic_mask] - 0.5 * self.delta)
total_loss = np.sum(quadratic_loss) + np.sum(linear_loss)
return float(total_loss / len(self.data))
return y_pred
class RegularizedError(ErrorFunction):
"""
Error function with L1, L2, or elastic net regularization.
Combines a base error function with parameter regularization:
Total Error = Base Error + λ₁ * L1_penalty + λ₂ * L2_penalty
This is useful for preventing overfitting and promoting sparse solutions.
Parameters
----------
data : np.ndarray
Reference/observed data.
base_error : str or ErrorFunction
Base error function ('mse', 'mae', 'rmse') or custom ErrorFunction instance.
l1_lambda : float, optional
L1 regularization strength (promotes sparsity). Default is 0.0.
l2_lambda : float, optional
L2 regularization strength (promotes smoothness). Default is 0.0.
param_getter : callable, optional
Function to extract parameters for regularization. If None,
regularization is not applied (requires external parameter passing).
Examples
--------
>>> data = np.array([1.0, 2.0, 3.0, 4.0])
>>> reg_func = RegularizedError(data, 'mse', l1_lambda=0.01, l2_lambda=0.1)
>>>
>>> # Usage in optimization (parameters passed externally)
>>> def error_with_params(y_pred, params):
... base_error = reg_func(y_pred)
... l1_penalty = np.sum(np.abs(list(params.values())))
... l2_penalty = np.sum([p**2 for p in params.values()])
... return base_error + 0.01 * l1_penalty + 0.1 * l2_penalty
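>>> # Or pass the parameter dictionary directly to the error function
>>> # (parameter names below are illustrative):
>>> prediction = np.array([1.1, 2.2, 2.8, 4.1])
>>> error = reg_func(prediction, {'k1': 0.5, 'k2': 1.2})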
"""
def __init__(self, data: np.ndarray, base_error: Union[str, ErrorFunction] = 'mse',
l1_lambda: float = 0.0, l2_lambda: float = 0.0,
param_getter: Optional[Callable] = None):
super().__init__(data)
# Initialize base error function
if isinstance(base_error, str):
error_map = {
'mse': MSE(data),
'mae': MAE(data),
'rmse': RMSE(data)
}
if base_error not in error_map:
raise ValueError(f"Unknown base error: {base_error}")
self.base_error = error_map[base_error]
else:
self.base_error = base_error
self.l1_lambda = l1_lambda
self.l2_lambda = l2_lambda
self.param_getter = param_getter
if l1_lambda < 0 or l2_lambda < 0:
raise ValueError("Regularization parameters must be non-negative")
def __call__(self, y_pred: np.ndarray, params: Optional[dict] = None) -> float:
"""
Compute regularized error.
Parameters
----------
y_pred : np.ndarray
Model predictions.
params : dict, optional
Parameter dictionary for regularization. If None, only base error is computed.
Returns
-------
float
Total error including regularization terms.
"""
base_error_val = self.base_error(y_pred)
if params is None:
if self.l1_lambda > 0 or self.l2_lambda > 0:
warnings.warn("Regularization requested but no parameters provided")
return base_error_val
# Compute regularization terms
param_values = np.array(list(params.values()))
l1_penalty = 0.0
l2_penalty = 0.0
if self.l1_lambda > 0:
l1_penalty = self.l1_lambda * np.sum(np.abs(param_values))
if self.l2_lambda > 0:
l2_penalty = self.l2_lambda * np.sum(param_values**2)
return float(base_error_val + l1_penalty + l2_penalty)
class WeightedError(ErrorFunction):
"""
Weighted error function for handling different importance of data points.
Applies weights to individual data points, allowing some measurements
to contribute more to the total error than others.
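For the default 'mse' base error the computed value is:
Weighted MSE = (1/n) * Σ wᵢ * (y_predᵢ - y_dataᵢ)²
where the weights are rescaled internally so that Σ wᵢ = n.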
Parameters
----------
data : np.ndarray
Reference/observed data.
weights : np.ndarray
Weights for each data point. Higher weights = more importance.
base_error : str, optional
Base error type ('mse', 'mae'). Default is 'mse'.
Examples
--------
>>> data = np.array([1.0, 2.0, 3.0, 4.0])
>>> weights = np.array([1.0, 2.0, 1.0, 0.5]) # Different importance
>>> weighted_func = WeightedError(data, weights, 'mse')
>>> prediction = np.array([1.1, 2.2, 2.8, 4.1])
>>> error = weighted_func(prediction)
>>> print(f"Weighted MSE: {error:.4f}")
"""
def __init__(self, data: np.ndarray, weights: np.ndarray, base_error: str = 'mse'):
super().__init__(data)
self.weights = np.asarray(weights)
if self.weights.shape != self.data.shape:
raise ValueError(f"Weights shape {self.weights.shape} does not match "
f"data shape {self.data.shape}")
if (self.weights < 0).any():
raise ValueError("All weights must be non-negative")
# Normalize weights to sum to number of data points
weight_sum = np.sum(self.weights)
if weight_sum > 0:
self.weights = self.weights * len(self.data) / weight_sum
else:
raise ValueError("Sum of weights must be positive")
self.base_error = base_error
def __call__(self, y_pred: np.ndarray) -> float:
"""Compute weighted error."""
y_pred = self._validate_prediction(y_pred)
if not np.isscalar(y_pred):
if self.base_error == 'mse':
residuals = (y_pred - self.data)**2
elif self.base_error == 'mae':
residuals = np.abs(y_pred - self.data)
else:
raise ValueError(f"Unknown base error: {self.base_error}")
weighted_error = np.sum(self.weights * residuals) / len(self.data)
return float(weighted_error)
return y_pred
# Convenience functions for backward compatibility and ease of use
def mse(data: np.ndarray) -> MSE:
"""
Create an MSE error function.
This convenience wrapper will be deprecated in a future version.
Parameters
----------
data : np.ndarray
Reference data for comparison.
Returns
-------
MSE
Configured MSE error function.
Examples
--------
>>> import numpy as np
>>> data = np.array([1.0, 2.0, 3.0])
>>> error_func = mse(data)
>>> prediction = np.array([1.1, 2.1, 2.9])
>>> error = error_func(prediction)
"""
return MSE(data)
def chisquared(data: np.ndarray, sigma: Union[np.ndarray, float],
normalize: bool = False) -> ChiSquaredMSE:
"""
Create a Chi-squared error function.
This convenience wrapper will be deprecated in a future version.
Parameters
----------
data : np.ndarray
Reference data.
sigma : np.ndarray or float
Standard deviation for each data point.
normalize : bool, optional
Whether to normalize by number of points.
Returns
-------
ChiSquaredMSE
Configured Chi-squared error function.
"""
return ChiSquaredMSE(data, sigma=sigma, normalize=normalize)
def mae(data: np.ndarray) -> MAE:
"""
Will be deprecated in future versions.
Create MAE error function."""
return MAE(data)
def rmse(data: np.ndarray) -> RMSE:
"""
Will be deprecated in future versions.Create RMSE error function."""
return RMSE(data)
def huber(data: np.ndarray, delta: float = 1.0) -> HuberLoss:
"""
Will be deprecated in future versions.Create Huber loss error function."""
return HuberLoss(data, delta=delta)
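# Minimal smoke test when the module is run as a script (illustrative only;
# the values below are arbitrary and not part of the public API).
if __name__ == "__main__":
    data = np.array([1.0, 2.0, 3.0, 4.0])
    prediction = np.array([1.1, 2.2, 2.8, 4.1])
    for name, func in [("MSE", MSE(data)),
                       ("MAE", MAE(data)),
                       ("RMSE", RMSE(data)),
                       ("Huber(delta=0.5)", HuberLoss(data, delta=0.5))]:
        print(f"{name}: {func(prediction):.4f}")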