"""
=================
Results Interface
=================
This module provides an interface to the :class:`ResultsManager <vivarium.engine.framework.results.manager.ResultsManager>`.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, Union
import pandas as pd
from pandas.core.groupby.generic import DataFrameGroupBy
from vivarium.engine.framework.event import Event
from vivarium.engine.framework.lifecycle import lifecycle_states
from vivarium.engine.framework.results.observation import (
AddingObservation,
ConcatenatingObservation,
StratifiedObservation,
UnstratifiedObservation,
)
from vivarium.engine.manager import Interface
from vivarium.engine.types import ScalarMapper, VectorMapper
if TYPE_CHECKING:
from vivarium.engine.framework.results.manager import ResultsManager
# Signature of callables that fold newly gathered observations into the running
# results: (existing_results, new_observations) -> updated_results.
ResultsUpdater = Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]
"""A Callable that takes existing results and new observations and returns updated results."""

# Signature of callables that post-process raw results for a given measure:
# (measure, results) -> formatted_results.
ResultsFormatter = Callable[[str, pd.DataFrame], pd.DataFrame]
"""A Callable that takes a measure as a string and a DataFrame of observation results
and returns formatted results."""

# Everything a gatherer may be handed: a plain DataFrame, a grouped DataFrame,
# a tuple of strings, or None.
# NOTE(review): the role of the tuple-of-strings form is not visible in this
# module -- confirm against the observation classes.
ResultsGathererInput = Union[
    pd.DataFrame, DataFrameGroupBy, tuple[str, ...], None # type: ignore [type-arg]
]

# Signature of callables that produce new observation results from one
# ResultsGathererInput value.
ResultsGatherer = Callable[[ResultsGathererInput], pd.DataFrame]
"""A Callable that optionally takes a possibly stratified population and returns
new observation results."""
def _required_function_placeholder(
    *args: ResultsGathererInput
    | tuple[pd.DataFrame, pd.DataFrame]
    | tuple[str, pd.DataFrame],
    **kwargs: Any,
) -> pd.DataFrame:
    """Return an empty DataFrame, ignoring every argument.

    This function is a sentinel: it is used as the default value for callable
    parameters that the caller is required to supply, so that a missing
    argument can be detected by comparing against this function.

    Parameters
    ----------
    *args
        Ignored positional arguments.
    **kwargs
        Ignored keyword arguments.

    Returns
    -------
        A new, empty DataFrame.
    """
    placeholder_result = pd.DataFrame()
    return placeholder_result
def _convert_object_cols_to_categorical(results: pd.DataFrame) -> pd.DataFrame:
    """Cast every object-dtype column of ``results`` to categorical dtype.

    Parameters
    ----------
    results
        The DataFrame to convert. It is modified in place.

    Returns
    -------
        The same DataFrame object that was passed in, with its object-dtype
        columns replaced by categorical ones.
    """
    target_cols = results.select_dtypes(include=["object"]).columns
    as_category = results[target_cols].astype("category")
    # Assign back onto the caller's frame; the conversion is in place.
    results[target_cols] = as_category
    return results
def _default_stratified_observation_formatter(
    measure: str, results: pd.DataFrame
) -> pd.DataFrame:
    """Format stratified results by flattening the index and categorizing strings.

    Parameters
    ----------
    measure
        The measure being formatted. It is not used by this formatter.
    results
        The raw observation results.

    Returns
    -------
        The results with the index reset into columns and all object-dtype
        columns converted to categorical dtype.
    """
    flattened = results.reset_index()
    return _convert_object_cols_to_categorical(flattened)
def _default_unstratified_observation_formatter(
    measure: str, results: pd.DataFrame
) -> pd.DataFrame:
    """Pass unstratified results through without any formatting.

    Parameters
    ----------
    measure
        The measure being formatted. It is not used by this formatter.
    results
        The raw observation results.

    Returns
    -------
        The very same ``results`` object that was passed in.
    """
    return results
class PopulationFilter(NamedTuple):
    """Immutable pairing of a population query string with the include_untracked flag."""

    # Pandas query string used to filter the population; empty by default.
    query: str = ""
    # Whether untracked simulants are included; off by default.
    include_untracked: bool = False
class ResultsInterface(Interface):
    """Builder interface for the results management system.

    The results management system allows users to delegate results production
    to the simulation framework. This process attempts to roughly mimic the
    groupby-apply logic commonly done when manipulating :mod:`pandas`
    DataFrames. Good encapsulation of simulation logic typically has
    results production separated from the modeling code into specialized
    `Observer` components. This often highlights the need for transformations
    of the simulation state into representations that aren't needed for
    modeling, but are required for the stratification of produced results.

    The purpose of this interface is to provide controlled access to a results
    backend by means of the builder object; it exposes methods to register both
    stratifications and results producers (referred to as "observations").

    Notes
    -----
    Several methods declare list parameters with a shared ``[]`` default. Those
    lists are copied before being handed to the manager so that nothing
    downstream can mutate a default (or a caller's list) and leak state into
    later registrations.
    """

    def __init__(self, manager: ResultsManager) -> None:
        """Store the manager that every registration call is delegated to."""
        self._manager: ResultsManager = manager
        self._name = "results_interface"

    @property
    def name(self) -> str:
        """The name of this interface."""
        return self._name

    ##################################
    # Stratification-related methods #
    ##################################

    def register_stratification(
        self,
        name: str,
        categories: list[str],
        excluded_categories: list[str] | None = None,
        mapper: VectorMapper | ScalarMapper | None = None,
        is_vectorized: bool = False,
        requires_attributes: list[str] = [],
    ) -> None:
        """Registers a stratification that can be used by stratified observations.

        Parameters
        ----------
        name
            Name of the stratification.
        categories
            Exhaustive list of all possible stratification values.
        excluded_categories
            List of possible stratification values to exclude from results processing.
            If None (the default), will use exclusions as defined in the configuration.
        mapper
            A callable that maps the population attributes specified by the
            `requires_attributes` argument to the stratification categories. It can
            either map the entire population or an individual simulant. A simulation
            will fail if the `mapper` ever produces an invalid value.
        is_vectorized
            True if the `mapper` function will map the entire population, and False
            if it will only map a single simulant.
        requires_attributes
            The population attributes that are required by the `mapper` to produce
            the stratification.
        """
        self._manager.register_stratification(
            name,
            categories,
            excluded_categories,
            mapper,
            is_vectorized,
            # Copy so the shared default list is never exposed to the manager.
            list(requires_attributes),
        )

    def register_binned_stratification(
        self,
        target: str,
        binned_column: str,
        bin_edges: Sequence[int | float] = [],
        labels: list[str] = [],
        excluded_categories: list[str] | None = None,
        **cut_kwargs: int | str | bool,
    ) -> None:
        """Registers a binned stratification that can be used by stratified observations.

        Parameters
        ----------
        target
            Name of the population attribute to be binned.
        binned_column
            Name of the (binned) stratification.
        bin_edges
            List of scalars defining the bin edges, passed to :func:`pandas.cut`.
            The length must be equal to the length of `labels` plus 1.
        labels
            List of string labels for bins. The length must be equal to the length
            of `bin_edges` minus 1.
        excluded_categories
            List of possible stratification values to exclude from results processing.
            If None (the default), will use exclusions as defined in the configuration.
        **cut_kwargs
            Keyword arguments for :func:`pandas.cut`.
        """
        self._manager.register_binned_stratification(
            target,
            binned_column,
            # Forwarded untouched: it is typed as a read-only Sequence and is
            # documented as being passed on to pandas.cut as given.
            bin_edges,
            # Copy so the shared default list is never exposed to the manager.
            list(labels),
            excluded_categories,
            **cut_kwargs,
        )

    ###############################
    # Observation-related methods #
    ###############################

    def register_stratified_observation(
        self,
        name: str,
        pop_filter: str = "",
        include_untracked: bool = False,
        when: str = lifecycle_states.COLLECT_METRICS,
        requires_attributes: list[str] = [],
        results_updater: ResultsUpdater = _required_function_placeholder,
        results_formatter: ResultsFormatter = _default_stratified_observation_formatter,
        additional_stratifications: list[str] = [],
        excluded_stratifications: list[str] = [],
        aggregator_sources: list[str] | None = None,
        aggregator: Callable[[pd.DataFrame], float | pd.Series[Any]] = len,
        to_observe: Callable[[Event], bool] = lambda event: True,
    ) -> None:
        """Registers a stratified observation to the results system.

        Parameters
        ----------
        name
            Name of the observation. It will also be the name of the output results file
            for this particular observation.
        pop_filter
            A Pandas query filter string to filter the population down to the simulants who should
            be considered for the observation.
        include_untracked
            Whether to include untracked simulants in this observation.
        when
            Name of the lifecycle phase the observation should happen. Valid values are:
            "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
        requires_attributes
            The population attributes that are required by the `aggregator`.
        results_updater
            Function that updates existing raw observation results with newly gathered results.
            Required; there is no usable default.
        results_formatter
            Function that formats the raw observation results.
        additional_stratifications
            List of additional :class:`Stratification <vivarium.engine.framework.results.stratification.Stratification>`
            names by which to stratify this observation.
        excluded_stratifications
            List of default :class:`Stratification <vivarium.engine.framework.results.stratification.Stratification>`
            names to remove from this observation.
        aggregator_sources
            List of population view columns to be used in the `aggregator`.
        aggregator
            Function that computes the quantity for this observation.
        to_observe
            Function that determines whether to perform an observation on this Event.

        Raises
        ------
        ValueError
            If any required callable arguments are missing.
        """
        self._check_for_required_callables(name, {"results_updater": results_updater})
        self._manager.register_observation(
            observation_type=StratifiedObservation,
            name=name,
            population_filter=PopulationFilter(pop_filter, include_untracked),
            when=when,
            # List arguments are copied so the shared defaults stay pristine.
            requires_attributes=list(requires_attributes),
            results_updater=results_updater,
            results_formatter=results_formatter,
            additional_stratifications=list(additional_stratifications),
            excluded_stratifications=list(excluded_stratifications),
            aggregator_sources=aggregator_sources,
            aggregator=aggregator,
            to_observe=to_observe,
        )

    def register_unstratified_observation(
        self,
        name: str,
        pop_filter: str = "",
        include_untracked: bool = False,
        when: str = lifecycle_states.COLLECT_METRICS,
        requires_attributes: list[str] = [],
        results_gatherer: ResultsGatherer = _required_function_placeholder,
        results_updater: ResultsUpdater = _required_function_placeholder,
        results_formatter: ResultsFormatter = _default_unstratified_observation_formatter,
        to_observe: Callable[[Event], bool] = lambda event: True,
    ) -> None:
        """Registers an unstratified observation to the results system.

        Parameters
        ----------
        name
            Name of the observation. It will also be the name of the output results file
            for this particular observation.
        pop_filter
            A Pandas query filter string to filter the population down to the simulants who should
            be considered for the observation.
        include_untracked
            Whether to include untracked simulants in this observation.
        when
            Name of the lifecycle phase the observation should happen. Valid values are:
            "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
        requires_attributes
            The population attributes that are required by the `results_gatherer`.
        results_gatherer
            Function that gathers the latest observation results.
            Required; there is no usable default.
        results_updater
            Function that updates existing raw observation results with newly gathered results.
            Required; there is no usable default.
        results_formatter
            Function that formats the raw observation results.
        to_observe
            Function that determines whether to perform an observation on this Event.

        Raises
        ------
        ValueError
            If any required callable arguments are missing.
        """
        required_callables: dict[str, Callable[..., pd.DataFrame]] = {
            "results_gatherer": results_gatherer,
            "results_updater": results_updater,
        }
        self._check_for_required_callables(name, required_callables)
        self._manager.register_observation(
            observation_type=UnstratifiedObservation,
            name=name,
            population_filter=PopulationFilter(pop_filter, include_untracked),
            when=when,
            # Copy so the shared default list is never exposed to the manager.
            requires_attributes=list(requires_attributes),
            results_updater=results_updater,
            results_gatherer=results_gatherer,
            results_formatter=results_formatter,
            to_observe=to_observe,
        )

    def register_adding_observation(
        self,
        name: str,
        pop_filter: str = "",
        include_untracked: bool = False,
        when: str = lifecycle_states.COLLECT_METRICS,
        requires_attributes: list[str] = [],
        results_formatter: ResultsFormatter = _default_stratified_observation_formatter,
        additional_stratifications: list[str] = [],
        excluded_stratifications: list[str] = [],
        aggregator_sources: list[str] | None = None,
        aggregator: Callable[[pd.DataFrame], int | float | pd.Series[int | float]] = len,
        to_observe: Callable[[Event], bool] = lambda event: True,
    ) -> None:
        """Registers an adding observation to the results system.

        An "adding" observation is one that adds/sums new results to existing
        result values.

        Notes
        -----
        An adding observation is a specific type of stratified observation.

        Parameters
        ----------
        name
            Name of the observation. It will also be the name of the output results file
            for this particular observation.
        pop_filter
            A Pandas query filter string to filter the population down to the simulants who should
            be considered for the observation.
        include_untracked
            Whether to include untracked simulants in this observation.
        when
            Name of the lifecycle phase the observation should happen. Valid values are:
            "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
        requires_attributes
            The population attributes that are required by the `aggregator`.
        results_formatter
            Function that formats the raw observation results.
        additional_stratifications
            List of additional :class:`Stratification <vivarium.engine.framework.results.stratification.Stratification>`
            names by which to stratify this observation.
        excluded_stratifications
            List of default :class:`Stratification <vivarium.engine.framework.results.stratification.Stratification>`
            names to remove from this observation.
        aggregator_sources
            List of population view columns to be used in the `aggregator`.
        aggregator
            Function that computes the quantity for this observation.
        to_observe
            Function that determines whether to perform an observation on this Event.
        """
        self._manager.register_observation(
            observation_type=AddingObservation,
            name=name,
            population_filter=PopulationFilter(pop_filter, include_untracked),
            when=when,
            # List arguments are copied so the shared defaults stay pristine.
            requires_attributes=list(requires_attributes),
            results_formatter=results_formatter,
            additional_stratifications=list(additional_stratifications),
            excluded_stratifications=list(excluded_stratifications),
            aggregator_sources=aggregator_sources,
            aggregator=aggregator,
            to_observe=to_observe,
        )

    def register_concatenating_observation(
        self,
        name: str,
        pop_filter: str = "",
        include_untracked: bool = False,
        when: str = lifecycle_states.COLLECT_METRICS,
        requires_attributes: list[str] = [],
        results_formatter: ResultsFormatter = _default_unstratified_observation_formatter,
        to_observe: Callable[[Event], bool] = lambda event: True,
    ) -> None:
        """Registers a concatenating observation to the results system.

        A "concatenating" observation is one that concatenates new results to
        existing results.

        Notes
        -----
        A concatenating observation is a specific type of unstratified observation.

        Parameters
        ----------
        name
            Name of the observation. It will also be the name of the output results file
            for this particular observation.
        pop_filter
            A Pandas query filter string to filter the population down to the simulants who should
            be considered for the observation.
        include_untracked
            Whether to include untracked simulants in this observation.
        when
            Name of the lifecycle phase the observation should happen. Valid values are:
            "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
        requires_attributes
            The population attributes that are required by this observation.
        results_formatter
            Function that formats the raw observation results.
        to_observe
            Function that determines whether to perform an observation on this Event.
        """
        self._manager.register_observation(
            observation_type=ConcatenatingObservation,
            name=name,
            population_filter=PopulationFilter(pop_filter, include_untracked),
            when=when,
            # Copy so the shared default list is never exposed to the manager.
            requires_attributes=list(requires_attributes),
            results_formatter=results_formatter,
            to_observe=to_observe,
        )

    @staticmethod
    def _check_for_required_callables(
        observation_name: str,
        required_callables: dict[str, ResultsFormatter | ResultsGatherer | ResultsUpdater],
    ) -> None:
        """Raises a ValueError if any required callable arguments are missing.

        A callable counts as missing when it is still the
        ``_required_function_placeholder`` sentinel used as the parameter default.

        Parameters
        ----------
        observation_name
            Name of the observation being registered; used in the error message.
        required_callables
            Mapping of argument name to the callable supplied for it.

        Raises
        ------
        ValueError
            If one or more of the callables is the placeholder sentinel.
        """
        # Identity, not equality: only the sentinel object itself means "missing".
        missing = [
            arg_name
            for arg_name, func in required_callables.items()
            if func is _required_function_placeholder
        ]
        if missing:
            raise ValueError(
                f"Observation '{observation_name}' is missing required callable(s): {missing}"
            )