"""
===========================
Framework Utility Functions
===========================
Collection of utility functions shared by the ``vivarium`` framework.
"""
from __future__ import annotations
import functools
from bdb import BdbQuit
from collections.abc import Callable, Sequence
from importlib import import_module
from typing import Any, Literal, TypeVar, overload
import numpy as np
import pandas as pd
from loguru import logger
from vivarium.engine.types import NumberLike, NumericArray, Timedelta
# Type variable bounded by ``NumberLike`` (from ``vivarium.engine.types``).
# NOTE(review): not referenced by any function in this section — confirm whether it
# is still used elsewhere before removing it.
TimeValue = TypeVar("TimeValue", bound=NumberLike)
@overload
def from_yearly(value: int, time_step: Timedelta) -> float: ...


@overload
def from_yearly(value: float, time_step: Timedelta) -> float: ...


@overload
def from_yearly(value: NumericArray, time_step: Timedelta) -> NumericArray: ...


@overload
def from_yearly(
    value: pd.Series[int] | pd.Series[float], time_step: Timedelta
) -> pd.Series[float]: ...


@overload
def from_yearly(value: pd.DataFrame, time_step: Timedelta) -> pd.DataFrame: ...


def from_yearly(value: NumberLike, time_step: Timedelta) -> NumberLike:
    """Rescales a yearly rate to the size of a time step.

    Parameters
    ----------
    value
        A rate expressed per 365-day year.
    time_step
        The step length; only its ``total_seconds()`` is used.

    Returns
    -------
        ``value`` multiplied by the fraction of a 365-day year that
        ``time_step`` spans.
    """
    seconds_per_year = 60 * 60 * 24 * 365.0
    year_fraction = time_step.total_seconds() / seconds_per_year
    return value * year_fraction
@overload
def to_yearly(value: int, time_step: Timedelta) -> float: ...


@overload
def to_yearly(value: float, time_step: Timedelta) -> float: ...


@overload
def to_yearly(value: NumericArray, time_step: Timedelta) -> NumericArray: ...


@overload
def to_yearly(
    value: pd.Series[int] | pd.Series[float], time_step: Timedelta
) -> pd.Series[float]: ...


@overload
def to_yearly(value: pd.DataFrame, time_step: Timedelta) -> pd.DataFrame: ...


def to_yearly(value: NumberLike, time_step: Timedelta) -> NumberLike:
    """Converts a time-step-scaled rate back to a yearly rate.

    Parameters
    ----------
    value
        A rate expressed per ``time_step``.
    time_step
        The step length; only its ``total_seconds()`` is used.

    Returns
    -------
        ``value`` divided by the fraction of a 365-day year that ``time_step``
        spans, i.e. the undo of ``from_yearly``'s scaling.
    """
    seconds_per_year = 60 * 60 * 24 * 365.0
    year_fraction = time_step.total_seconds() / seconds_per_year
    return value / year_fraction
def rate_to_probability(
    rate: Sequence[float] | NumberLike,
    time_scaling_factor: float | int = 1.0,
    rate_conversion_type: Literal["linear", "exponential"] = "linear",
) -> NumericArray:
    """Converts a rate to a probability.

    Parameters
    ----------
    rate
        The rate to convert to a probability.
    time_scaling_factor
        The factor by which to scale the rates. This is usually the time step.
    rate_conversion_type
        The type of conversion to use. Default is "linear" for a simple multiplication
        of rate and time_scaling_factor. The other option is "exponential" which should be
        used for continuous time event driven models.

    Returns
    -------
        An array of floats representing the probability of the converted rates.

    Raises
    ------
    ValueError
        If an unsupported rate conversion type is provided.

    Notes
    -----
    Beware machine-specific floating point issues. We have encountered underflow
    when using the exponential conversion for rates greater than ~30,000. To avoid
    this, the exponential conversion caps the scaled rate
    (``rate * time_scaling_factor``) at 250, since exp(-250) is effectively zero
    for practical purposes. The ``rate`` argument is never modified in place.
    """
    if rate_conversion_type not in ["linear", "exponential"]:
        raise ValueError(
            f"Rate conversion type {rate_conversion_type} is not implemented. "
            "Allowable types are 'linear' or 'exponential'."
        )

    probability: NumericArray
    if rate_conversion_type == "linear":
        # NOTE: The default behavior for randomness streams is to use a rate that is already
        # scaled to the time step which is why the default time scaling factor is 1.0.
        # Use asarray to handle both scalars and arrays
        probability = np.asarray(rate) * time_scaling_factor
        # Clip to 1.0 if the probability is greater than 1.0.
        if np.any(probability > 1.0):
            probability = np.clip(probability, None, 1.0)
            logger.warning(
                "The rate to probability conversion resulted in a probability greater than 1.0. "
                "The probability has been clipped to 1.0 and indicates the rate is too high. "
            )
    else:
        # Cap the exponent rather than the raw rate: capping the raw rate would
        # distort results whenever time_scaling_factor != 1. np.minimum returns a
        # new array, so an ndarray/Series passed by the caller is never mutated
        # (np.asarray alone does not copy).
        scaled_rate = np.minimum(np.asarray(rate) * time_scaling_factor, 250.0)
        probability = 1 - np.exp(-scaled_rate)
    return probability
def probability_to_rate(
    probability: Sequence[float] | NumberLike,
    time_scaling_factor: float | int = 1.0,
    rate_conversion_type: Literal["linear", "exponential"] = "linear",
) -> NumericArray:
    """Converts a probability to a rate.

    Parameters
    ----------
    probability
        The probability to convert to a rate.
    time_scaling_factor
        The factor the probability was scaled by; the resulting rate is divided
        by it. This is usually the time step.
    rate_conversion_type
        The type of conversion to use. Default is "linear" for a simple division
        of probability by time_scaling_factor. The other option is "exponential" which
        should be used for continuous time event driven models; it computes
        ``-log(1 - probability) / time_scaling_factor``, the inverse of
        ``1 - exp(-rate * time_scaling_factor)``.

    Returns
    -------
        An array of floats representing the rate of the converted probabilities.
        With the exponential conversion a probability of exactly 1 maps to ``inf``.

    Raises
    ------
    ValueError
        If an unsupported rate conversion type is provided.
    """
    # NOTE: The default behavior for randomness streams is to use a rate that is already
    # scaled to the time step which is why the default time scaling factor is 1.0.
    if rate_conversion_type not in ["linear", "exponential"]:
        raise ValueError(
            f"Rate conversion type {rate_conversion_type} is not implemented. "
            "Allowable types are 'linear' or 'exponential'."
        )

    rate: NumericArray
    if rate_conversion_type == "linear":
        # Use asarray to handle both scalars and arrays
        rate = np.asarray(probability) / time_scaling_factor
    else:
        # log1p(-p) equals log(1 - p) but keeps precision for the small
        # per-step probabilities typical here. Dividing by the scaling factor
        # makes this the true inverse of the exponential rate -> probability
        # conversion (previously the factor was silently ignored).
        rate = -np.log1p(-np.asarray(probability)) / time_scaling_factor
    return rate
def collapse_nested_dict(
    d: dict[str, Any], prefix: str | None = None
) -> list[tuple[str, Any]]:
    """Flattens a nested dictionary into ``(dotted_key, value)`` pairs.

    Nested ``dict`` values are descended into, and their keys are joined to the
    parent key with ``"."``. Non-dict values become leaves. Pairs appear in the
    dictionaries' iteration order. Empty nested dicts contribute no pairs.

    Parameters
    ----------
    d
        The (possibly nested) dictionary to flatten.
    prefix
        Optional key path prepended to every key. ``None`` or ``""`` means no prefix.

    Returns
    -------
        A list of ``(dotted_key, leaf_value)`` tuples.
    """

    def _walk(mapping: dict[str, Any], parent: str | None):
        # Depth-first traversal that yields leaves in the same order as the
        # original recursive accumulation.
        for key, value in mapping.items():
            path = parent + "." + key if parent else key
            if isinstance(value, dict):
                yield from _walk(value, path)
            else:
                yield path, value

    return list(_walk(d, prefix))
def import_by_path(path: str) -> Callable[..., Any]:
    """Imports a class or function given its absolute path.

    Parameters
    ----------
    path
        Fully qualified dotted path to the object (e.g. "module.submodule.ClassName").

    Returns
    -------
        The imported class or function.

    Raises
    ------
    ValueError
        If ``path`` has no module component (e.g. "ClassName" or ".ClassName").
    ImportError
        If the module portion of ``path`` cannot be imported
        (``ModuleNotFoundError`` for a missing module).
    AttributeError
        If the module does not define the requested attribute.
    """
    module_path, _, attribute_name = path.rpartition(".")
    if not module_path:
        # Without this guard import_module("") fails with the opaque
        # "Empty module name" error; report the offending path instead.
        raise ValueError(
            f"Cannot import '{path}': expected a fully qualified dotted path "
            "such as 'package.module.ClassName'."
        )
    callable_attr: Callable[..., Any] = getattr(import_module(module_path), attribute_name)
    return callable_attr
def handle_exceptions(
    func: Callable[..., Any], logger: Any, with_debugger: bool
) -> Callable[..., Any]:
    """Wraps ``func`` so uncaught errors are logged and optionally debugged.

    ``BdbQuit`` and ``KeyboardInterrupt`` pass through untouched. Any other
    ``Exception`` is reported through ``logger.exception``. If ``with_debugger``
    is true, the traceback is printed and the user is dropped into
    ``pdb.post_mortem``, after which the wrapper returns ``None``. Otherwise the
    exception is re-raised.

    Parameters
    ----------
    func
        The callable to wrap.
    logger
        Any object exposing an ``exception(message)`` method.
    with_debugger
        Whether to open a post-mortem debugger instead of re-raising.

    Returns
    -------
        A wrapper with ``func``'s metadata (via ``functools.wraps``).
    """

    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        try:
            outcome = func(*args, **kwargs)
        except (BdbQuit, KeyboardInterrupt):
            # Quitting the debugger or pressing Ctrl-C must propagate as-is.
            raise
        except Exception as exc:
            logger.exception("Uncaught exception {}".format(exc))
            if not with_debugger:
                raise
            import pdb
            import traceback

            traceback.print_exc()
            pdb.post_mortem()
            return None
        return outcome

    return wrapped