# Source code for vivarium.engine.framework.results.interface

"""
=================
Results Interface
=================

This module provides an interface to the :class:`ResultsManager <vivarium.engine.framework.results.manager.ResultsManager>`.

"""
from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, Union

import pandas as pd
from pandas.core.groupby.generic import DataFrameGroupBy

from vivarium.engine.framework.event import Event
from vivarium.engine.framework.lifecycle import lifecycle_states
from vivarium.engine.framework.results.observation import (
    AddingObservation,
    ConcatenatingObservation,
    StratifiedObservation,
    UnstratifiedObservation,
)
from vivarium.engine.manager import Interface
from vivarium.engine.types import ScalarMapper, VectorMapper

if TYPE_CHECKING:
    from vivarium.engine.framework.results.manager import ResultsManager


ResultsUpdater = Callable[
    [pd.DataFrame, pd.DataFrame],
    pd.DataFrame,
]
"""A Callable that receives the existing results and the newly gathered observations,
in that order, and returns the updated results."""
ResultsFormatter = Callable[
    [str, pd.DataFrame],
    pd.DataFrame,
]
"""A Callable that receives a measure name (a string) and a DataFrame of raw observation
results, and returns the formatted results."""
ResultsGathererInput = Union[
    pd.DataFrame,
    DataFrameGroupBy,  # type: ignore [type-arg]
    tuple[str, ...],
    None,
]
"""The argument a ResultsGatherer may receive: a DataFrame, a DataFrameGroupBy,
a tuple of strings, or None."""
ResultsGatherer = Callable[
    [ResultsGathererInput],
    pd.DataFrame,
]
"""A Callable that optionally receives a (possibly stratified) population and returns
new observation results."""


def _required_function_placeholder(
    *args: ResultsGathererInput
    | tuple[pd.DataFrame, pd.DataFrame]
    | tuple[str, pd.DataFrame],
    **kwargs: Any,
) -> pd.DataFrame:
    """Returns and empty dataframe.

    Placeholder function to indicate that a required function is missing.
    """
    return pd.DataFrame()


def _convert_object_cols_to_categorical(results: pd.DataFrame) -> pd.DataFrame:
    """Convert object dtype columns to categorical dtype."""
    object_cols = results.select_dtypes(include=["object"]).columns
    results[object_cols] = results[object_cols].astype("category")
    return results


def _default_stratified_observation_formatter(
    measure: str, results: pd.DataFrame
) -> pd.DataFrame:
    """Flatten the index of stratified results and categorize object columns.

    ``measure`` is accepted to satisfy the formatter signature but is not used.
    The index levels are moved into columns via ``reset_index``. Any resulting
    ``object``-dtype columns are then cast to ``category``.
    """
    flattened = results.reset_index()
    return _convert_object_cols_to_categorical(flattened)


def _default_unstratified_observation_formatter(
    measure: str, results: pd.DataFrame
) -> pd.DataFrame:
    """Return the results unchanged."""
    return results


class PopulationFilter(NamedTuple):
    """Pairs a population query string with an ``include_untracked`` flag.

    Attributes
    ----------
    query
        A pandas query string used to filter the population; empty by default.
    include_untracked
        Whether untracked simulants are included; False by default.
    """

    query: str = ""
    include_untracked: bool = False
class ResultsInterface(Interface):
    """Builder interface for the results management system.

    The results management system allows users to delegate results production
    to the simulation framework. This process attempts to roughly mimic the
    groupby-apply logic commonly done when manipulating :mod:`pandas`
    DataFrames. Good encapsulation of simulation logic typically has results
    production separated from the modeling code into specialized `Observer`
    components. This often highlights the need for transformations of the
    simulation state into representations that aren't needed for modeling,
    but are required for the stratification of produced results.

    The purpose of this interface is to provide controlled access to a results
    backend by means of the builder object; it exposes methods to register
    both stratifications and results producers (referred to as
    "observations").

    """

    def __init__(self, manager: ResultsManager) -> None:
        self._manager: ResultsManager = manager
        self._name = "results_interface"

    @property
    def name(self) -> str:
        """The name of this interface (``"results_interface"``)."""
        return self._name

    ##################################
    # Stratification-related methods #
    ##################################

    def register_stratification(
        self,
        name: str,
        categories: list[str],
        excluded_categories: list[str] | None = None,
        mapper: VectorMapper | ScalarMapper | None = None,
        is_vectorized: bool = False,
        requires_attributes: list[str] | None = None,
    ) -> None:
        """Registers a stratification that can be used by stratified observations.

        Parameters
        ----------
        name
            Name of the stratification.
        categories
            Exhaustive list of all possible stratification values.
        excluded_categories
            List of possible stratification values to exclude from results
            processing. If None (the default), will use exclusions as defined
            in the configuration.
        mapper
            A callable that maps the population attributes specified by the
            `requires_attributes` argument to the stratification categories.
            It can either map the entire population or an individual simulant.
            A simulation will fail if the `mapper` ever produces an invalid
            value.
        is_vectorized
            True if the `mapper` function will map the entire population, and
            False if it will only map a single simulant.
        requires_attributes
            The population attributes that are required by the `mapper` to
            produce the stratification. If None (the default), an empty list
            is used.
        """
        # A fresh list per call: a `[]` default in the signature would be one
        # shared object handed to the manager on every call.
        requires_attributes = [] if requires_attributes is None else requires_attributes
        self._manager.register_stratification(
            name,
            categories,
            excluded_categories,
            mapper,
            is_vectorized,
            requires_attributes,
        )

    def register_binned_stratification(
        self,
        target: str,
        binned_column: str,
        bin_edges: Sequence[int | float] | None = None,
        labels: list[str] | None = None,
        excluded_categories: list[str] | None = None,
        **cut_kwargs: int | str | bool,
    ) -> None:
        """Registers a binned stratification that can be used by stratified observations.

        Parameters
        ----------
        target
            Name of the population attribute to be binned.
        binned_column
            Name of the (binned) stratification.
        bin_edges
            List of scalars defining the bin edges, passed to
            :func:`pandas.cut`. The length must be equal to the length of
            `labels` plus 1. If None (the default), an empty list is used.
        labels
            List of string labels for bins. The length must be equal to the
            length of `bin_edges` minus 1. If None (the default), an empty
            list is used.
        excluded_categories
            List of possible stratification values to exclude from results
            processing. If None (the default), will use exclusions as defined
            in the configuration.
        **cut_kwargs
            Keyword arguments for :func:`pandas.cut`.
        """
        # Fresh lists per call rather than shared mutable defaults.
        bin_edges = [] if bin_edges is None else bin_edges
        labels = [] if labels is None else labels
        self._manager.register_binned_stratification(
            target,
            binned_column,
            bin_edges,
            labels,
            excluded_categories,
            **cut_kwargs,
        )

    ###############################
    # Observation-related methods #
    ###############################

    def register_stratified_observation(
        self,
        name: str,
        pop_filter: str = "",
        include_untracked: bool = False,
        when: str = lifecycle_states.COLLECT_METRICS,
        requires_attributes: list[str] | None = None,
        results_updater: ResultsUpdater = _required_function_placeholder,
        results_formatter: ResultsFormatter = _default_stratified_observation_formatter,
        additional_stratifications: list[str] | None = None,
        excluded_stratifications: list[str] | None = None,
        aggregator_sources: list[str] | None = None,
        aggregator: Callable[[pd.DataFrame], float | pd.Series[Any]] = len,
        to_observe: Callable[[Event], bool] = lambda event: True,
    ) -> None:
        """Registers a stratified observation to the results system.

        Parameters
        ----------
        name
            Name of the observation. It will also be the name of the output
            results file for this particular observation.
        pop_filter
            A Pandas query filter string to filter the population down to the
            simulants who should be considered for the observation.
        include_untracked
            Whether to include untracked simulants in this observation.
        when
            Name of the lifecycle phase the observation should happen. Valid
            values are: "time_step__prepare", "time_step",
            "time_step__cleanup", or "collect_metrics".
        requires_attributes
            The population attributes that are required by the `aggregator`.
            If None (the default), an empty list is used.
        results_updater
            Function that updates existing raw observation results with newly
            gathered results.
        results_formatter
            Function that formats the raw observation results.
        additional_stratifications
            List of additional :class:`Stratification
            <vivarium.engine.framework.results.stratification.Stratification>`
            names by which to stratify this observation. If None (the
            default), an empty list is used.
        excluded_stratifications
            List of default :class:`Stratification
            <vivarium.engine.framework.results.stratification.Stratification>`
            names to remove from this observation. If None (the default), an
            empty list is used.
        aggregator_sources
            List of population view columns to be used in the `aggregator`.
        aggregator
            Function that computes the quantity for this observation.
        to_observe
            Function that determines whether to perform an observation on this
            Event.

        Raises
        ------
        ValueError
            If any required callable arguments are missing.
        """
        self._check_for_required_callables(name, {"results_updater": results_updater})
        # Fresh lists per call rather than shared mutable defaults.
        requires_attributes = [] if requires_attributes is None else requires_attributes
        additional_stratifications = (
            [] if additional_stratifications is None else additional_stratifications
        )
        excluded_stratifications = (
            [] if excluded_stratifications is None else excluded_stratifications
        )
        self._manager.register_observation(
            observation_type=StratifiedObservation,
            name=name,
            population_filter=PopulationFilter(pop_filter, include_untracked),
            when=when,
            requires_attributes=requires_attributes,
            results_updater=results_updater,
            results_formatter=results_formatter,
            additional_stratifications=additional_stratifications,
            excluded_stratifications=excluded_stratifications,
            aggregator_sources=aggregator_sources,
            aggregator=aggregator,
            to_observe=to_observe,
        )

    def register_unstratified_observation(
        self,
        name: str,
        pop_filter: str = "",
        include_untracked: bool = False,
        when: str = lifecycle_states.COLLECT_METRICS,
        requires_attributes: list[str] | None = None,
        results_gatherer: ResultsGatherer = _required_function_placeholder,
        results_updater: ResultsUpdater = _required_function_placeholder,
        results_formatter: ResultsFormatter = _default_unstratified_observation_formatter,
        to_observe: Callable[[Event], bool] = lambda event: True,
    ) -> None:
        """Registers an unstratified observation to the results system.

        Parameters
        ----------
        name
            Name of the observation. It will also be the name of the output
            results file for this particular observation.
        pop_filter
            A Pandas query filter string to filter the population down to the
            simulants who should be considered for the observation.
        include_untracked
            Whether to include untracked simulants in this observation.
        when
            Name of the lifecycle phase the observation should happen. Valid
            values are: "time_step__prepare", "time_step",
            "time_step__cleanup", or "collect_metrics".
        requires_attributes
            The population attributes that are required by the
            `results_gatherer`. If None (the default), an empty list is used.
        results_gatherer
            Function that gathers the latest observation results.
        results_updater
            Function that updates existing raw observation results with newly
            gathered results.
        results_formatter
            Function that formats the raw observation results.
        to_observe
            Function that determines whether to perform an observation on this
            Event.

        Raises
        ------
        ValueError
            If any required callable arguments are missing.
        """
        required_callables: dict[str, Callable[..., pd.DataFrame]] = {
            "results_gatherer": results_gatherer,
            "results_updater": results_updater,
        }
        self._check_for_required_callables(name, required_callables)
        # Fresh list per call rather than a shared mutable default.
        requires_attributes = [] if requires_attributes is None else requires_attributes
        self._manager.register_observation(
            observation_type=UnstratifiedObservation,
            name=name,
            population_filter=PopulationFilter(pop_filter, include_untracked),
            when=when,
            requires_attributes=requires_attributes,
            results_updater=results_updater,
            results_gatherer=results_gatherer,
            results_formatter=results_formatter,
            to_observe=to_observe,
        )

    def register_adding_observation(
        self,
        name: str,
        pop_filter: str = "",
        include_untracked: bool = False,
        when: str = lifecycle_states.COLLECT_METRICS,
        requires_attributes: list[str] | None = None,
        results_formatter: ResultsFormatter = _default_stratified_observation_formatter,
        additional_stratifications: list[str] | None = None,
        excluded_stratifications: list[str] | None = None,
        aggregator_sources: list[str] | None = None,
        aggregator: Callable[[pd.DataFrame], int | float | pd.Series[int | float]] = len,
        to_observe: Callable[[Event], bool] = lambda event: True,
    ) -> None:
        """Registers an adding observation to the results system.

        An "adding" observation is one that adds/sums new results to existing
        result values.

        Notes
        -----
        An adding observation is a specific type of stratified observation.

        Parameters
        ----------
        name
            Name of the observation. It will also be the name of the output
            results file for this particular observation.
        pop_filter
            A Pandas query filter string to filter the population down to the
            simulants who should be considered for the observation.
        include_untracked
            Whether to include untracked simulants in this observation.
        when
            Name of the lifecycle phase the observation should happen. Valid
            values are: "time_step__prepare", "time_step",
            "time_step__cleanup", or "collect_metrics".
        requires_attributes
            The population attributes that are required by the `aggregator`.
            If None (the default), an empty list is used.
        results_formatter
            Function that formats the raw observation results.
        additional_stratifications
            List of additional :class:`Stratification
            <vivarium.engine.framework.results.stratification.Stratification>`
            names by which to stratify this observation. If None (the
            default), an empty list is used.
        excluded_stratifications
            List of default :class:`Stratification
            <vivarium.engine.framework.results.stratification.Stratification>`
            names to remove from this observation. If None (the default), an
            empty list is used.
        aggregator_sources
            List of population view columns to be used in the `aggregator`.
        aggregator
            Function that computes the quantity for this observation.
        to_observe
            Function that determines whether to perform an observation on this
            Event.
        """
        # Fresh lists per call rather than shared mutable defaults.
        requires_attributes = [] if requires_attributes is None else requires_attributes
        additional_stratifications = (
            [] if additional_stratifications is None else additional_stratifications
        )
        excluded_stratifications = (
            [] if excluded_stratifications is None else excluded_stratifications
        )
        self._manager.register_observation(
            observation_type=AddingObservation,
            name=name,
            population_filter=PopulationFilter(pop_filter, include_untracked),
            when=when,
            requires_attributes=requires_attributes,
            results_formatter=results_formatter,
            additional_stratifications=additional_stratifications,
            excluded_stratifications=excluded_stratifications,
            aggregator_sources=aggregator_sources,
            aggregator=aggregator,
            to_observe=to_observe,
        )

    def register_concatenating_observation(
        self,
        name: str,
        pop_filter: str = "",
        include_untracked: bool = False,
        when: str = lifecycle_states.COLLECT_METRICS,
        requires_attributes: list[str] | None = None,
        results_formatter: ResultsFormatter = _default_unstratified_observation_formatter,
        to_observe: Callable[[Event], bool] = lambda event: True,
    ) -> None:
        """Registers a concatenating observation to the results system.

        A "concatenating" observation is one that concatenates new results to
        existing results.

        Notes
        -----
        A concatenating observation is a specific type of unstratified
        observation.

        Parameters
        ----------
        name
            Name of the observation. It will also be the name of the output
            results file for this particular observation.
        pop_filter
            A Pandas query filter string to filter the population down to the
            simulants who should be considered for the observation.
        include_untracked
            Whether to include untracked simulants in this observation.
        when
            Name of the lifecycle phase the observation should happen. Valid
            values are: "time_step__prepare", "time_step",
            "time_step__cleanup", or "collect_metrics".
        requires_attributes
            The population attributes that are required by this observation.
            If None (the default), an empty list is used.
        results_formatter
            Function that formats the raw observation results.
        to_observe
            Function that determines whether to perform an observation on this
            Event.
        """
        # Fresh list per call rather than a shared mutable default.
        requires_attributes = [] if requires_attributes is None else requires_attributes
        self._manager.register_observation(
            observation_type=ConcatenatingObservation,
            name=name,
            population_filter=PopulationFilter(pop_filter, include_untracked),
            when=when,
            requires_attributes=requires_attributes,
            results_formatter=results_formatter,
            to_observe=to_observe,
        )

    @staticmethod
    def _check_for_required_callables(
        observation_name: str,
        required_callables: dict[str, ResultsFormatter | ResultsGatherer | ResultsUpdater],
    ) -> None:
        """Raises a ValueError if any required callable arguments are missing.

        Parameters
        ----------
        observation_name
            Name of the observation being registered; used in the error message.
        required_callables
            Mapping of argument name to the callable supplied for it.

        Raises
        ------
        ValueError
            If any supplied callable is still the required-function placeholder.
        """
        # Identity check: the placeholder is a sentinel default, so only an
        # omitted argument can be this exact function object.
        missing = [
            arg_name
            for arg_name, func in required_callables.items()
            if func is _required_function_placeholder
        ]
        if missing:
            raise ValueError(
                f"Observation '{observation_name}' is missing required callable(s): {missing}"
            )