"""
===============
Results Manager
===============
"""
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Sequence
import pandas as pd
from vivarium.engine.framework.event import Event
from vivarium.engine.framework.lifecycle import lifecycle_states
from vivarium.engine.framework.results.context import ResultsContext
from vivarium.engine.framework.results.observation import Observation
from vivarium.engine.framework.results.stratification import (
Stratification,
get_mapped_col_name,
)
from vivarium.engine.manager import Manager
from vivarium.engine.types import ScalarMapper, VectorMapper
if TYPE_CHECKING:
from vivarium.engine.framework.engine import Builder
from vivarium.engine.framework.results.interface import PopulationFilter
[docs]
class ResultsManager(Manager):
"""Backend manager object for the results management system.
This class contains the public methods used by the :class:`ResultsInterface <vivarium.engine.framework.results.interface.ResultsInterface>`
to register stratifications and observations as well as the :meth:`get_results <get_results>`
method used to retrieve formatted results by the :class:`ResultsContext <vivarium.engine.framework.results.context.ResultsContext>`.
"""
CONFIGURATION_DEFAULTS = {
"stratification": {
"default": [],
"excluded_categories": {},
}
}
def __init__(self) -> None:
self._raw_results: defaultdict[str, pd.DataFrame] = defaultdict()
self._results_context = ResultsContext()
self._name = "results_manager"
@property
def name(self) -> str:
return self._name
[docs]
def get_results(self) -> dict[str, pd.DataFrame]:
"""Gets the measure-specific formatted results in a dictionary.
Returns
-------
A dictionary of measure-specific formatted results. The keys are the
measure names and the values are the respective results.
"""
formatted = {}
for name, observation in self._results_context.observations.items():
results = self._raw_results[name].copy()
formatted[name] = observation.results_formatter(name, results)
return formatted
# noinspection PyAttributeOutsideInit
[docs]
def setup(self, builder: "Builder") -> None:
"""Sets up the results manager."""
self._results_context.setup(builder)
self.logger = builder.logging.get_logger(self.name)
self.population_view = builder.population.get_view()
self.clock = builder.time.clock()
self.step_size = builder.time.step_size()
builder.event.register_listener(lifecycle_states.POST_SETUP, self.on_post_setup)
builder.event.register_listener(
lifecycle_states.TIME_STEP_PREPARE, self.on_time_step_prepare
)
builder.event.register_listener(lifecycle_states.TIME_STEP, self.on_time_step)
builder.event.register_listener(
lifecycle_states.TIME_STEP_CLEANUP, self.on_time_step_cleanup
)
builder.event.register_listener(
lifecycle_states.COLLECT_METRICS, self.on_collect_metrics
)
self.set_default_stratifications(builder)
[docs]
def on_post_setup(self, _: Event) -> None:
"""Sets stratifications on observations and initializes results for each measure."""
self._results_context.set_stratifications()
for name, observation in self._results_context.observations.items():
self._raw_results[name] = observation.results_initializer()
[docs]
def on_time_step_prepare(self, event: Event) -> None:
"""Defines the listener callable for the time_step__prepare phase."""
self.gather_results(event)
[docs]
def on_time_step(self, event: Event) -> None:
"""Defines the listener callable for the time_step phase."""
self.gather_results(event)
[docs]
def on_time_step_cleanup(self, event: Event) -> None:
"""Defines the listener callable for the time_step__cleanup phase."""
self.gather_results(event)
[docs]
def on_collect_metrics(self, event: Event) -> None:
"""Defines the listener callable for the collect_metrics phase."""
self.gather_results(event)
[docs]
def gather_results(self, event: Event) -> None:
"""Updates existing results with any new results."""
observations = self._results_context.get_observations(event)
stratifications = self._results_context.get_stratifications(observations)
if not observations or event.index.empty:
return
population = self._prepare_population(event, observations, stratifications)
for results_group, measure, updater in self._results_context.gather_results(
population, event.name, observations
):
self._raw_results[measure] = updater(self._raw_results[measure], results_group)
##########################
# Stratification methods #
##########################
[docs]
def set_default_stratifications(self, builder: "Builder") -> None:
"""Sets the default stratifications for the results context.
This passes the default stratifications from the configuration to the
:class:`ResultsContext <vivarium.engine.framework.results.context.ResultsContext>`
:meth:`set_default_stratifications` method to be set.
Parameters
----------
builder
The builder object for the simulation.
"""
default_stratifications = builder.configuration.stratification.default
self._results_context.set_default_stratifications(default_stratifications)
[docs]
def register_stratification(
self,
name: str,
categories: list[str],
excluded_categories: list[str] | None,
mapper: VectorMapper | ScalarMapper | None,
is_vectorized: bool,
requires_attributes: list[str] = [],
) -> None:
"""Registers a stratification that can be used by stratified observations.
Adds a stratification to the :class:`ResultsContext <vivarium.engine.framework.results.context.ResultsContext>`
as well as the stratification's required resources to this manager.
Parameters
----------
name
Name of the stratification.
categories
Exhaustive list of all possible stratification values.
excluded_categories
List of possible stratification values to exclude from results processing.
If None (the default), will use exclusions as defined in the configuration.
mapper
A callable that maps population attributes specified by the
`requires_attributes` argument to the stratification categories. It can
either map the entire population or an individual simulant. A simulation
will fail if the `mapper` ever produces an invalid value.
is_vectorized
True if the `mapper` function will map the entire population, and False
if it will only map a single simulant.
requires_attributes
A list of the state table columns that are required by the `mapper`
to produce the stratification.
"""
self.logger.debug(f"Registering stratification {name}")
self._results_context.add_stratification(
name=name,
requires_attributes=requires_attributes,
categories=categories,
excluded_categories=excluded_categories,
mapper=mapper,
is_vectorized=is_vectorized,
)
[docs]
def register_binned_stratification(
self,
target: str,
binned_column: str,
bin_edges: Sequence[int | float],
labels: list[str],
excluded_categories: list[str] | None,
**cut_kwargs: int | str | bool,
) -> None:
"""Registers a continuous `target` quantity to observe into bins in a `binned_column`.
Parameters
----------
target
Name of population attribute to be binned.
binned_column
Name of the (binned) stratification.
bin_edges
List of scalars defining the bin edges, passed to :meth: pandas.cut.
The length must be equal to the length of `labels` plus 1.
labels
List of string labels for bins. The length must be equal to the length
of `bin_edges` minus 1.
excluded_categories
List of possible stratification values to exclude from results processing.
If None (the default), will use exclusions as defined in the configuration.
**cut_kwargs
Keyword arguments for :meth: pandas.cut.
"""
if not isinstance(labels, list) or not all(
[isinstance(label, str) for label in labels]
):
raise ValueError(
f"Labels must be a list of strings when registering a binned stratification, but labels was {labels} when registering the {binned_column} stratification."
)
def _bin_data(data: pd.DataFrame) -> pd.Series[Any]:
"""Use pandas.cut to bin continuous values"""
data = data.squeeze(axis=1)
if not isinstance(data, pd.Series):
raise ValueError(f"Expected a Series, but got type {type(data)}.")
data = pd.cut(
data, bin_edges, labels=labels, right=False, include_lowest=True, **cut_kwargs
)
return data
if len(bin_edges) != len(labels) + 1:
raise ValueError(
f"The number of bin edges plus 1 ({len(bin_edges)+1}) does not "
f"match the number of labels ({len(labels)})"
)
self.register_stratification(
name=binned_column,
categories=labels,
excluded_categories=excluded_categories,
mapper=_bin_data,
is_vectorized=True,
requires_attributes=[target],
)
[docs]
def register_observation(
self,
observation_type: type[Observation],
name: str,
population_filter: PopulationFilter,
when: str,
requires_attributes: list[str],
**kwargs: Any,
) -> None:
"""Registers an observation to the results system.
Parameters
----------
observation_type
Specific class type of observation to register.
name
Name of the observation. It will also be the name of the output results file
for this particular observation.
population_filter
A named tuple of population filtering details. The first item is a Pandas
query string to filter the population down to the simulants who should be
considered for the observation. The second item is a boolean indicating whether
to include untracked simulants from the observation.
when
Name of the lifecycle phase the observation should happen. Valid values are:
"time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
requires_attributes
The population attributes that are required to compute the observation.
**kwargs
Additional keyword arguments to be passed to the observation's constructor.
"""
self.logger.debug(f"Registering observation {name}")
if any(not isinstance(attribute, str) for attribute in requires_attributes):
raise TypeError(
f"All required attributes must be strings, but got {requires_attributes} when registering observation {name}."
)
if observation_type.is_stratified():
stratifications = self._get_stratifications(
list(kwargs.get("stratifications", [])),
list(kwargs.get("additional_stratifications", [])),
list(kwargs.get("excluded_stratifications", [])),
)
# Remove the unused kwargs before passing to the results context registration
del kwargs["additional_stratifications"]
del kwargs["excluded_stratifications"]
else:
stratifications = None
self._results_context.register_observation(
observation_type=observation_type,
name=name,
population_filter=population_filter,
when=when,
requires_attributes=requires_attributes,
stratifications=stratifications,
**kwargs,
)
##################
# Helper methods #
##################
def _get_stratifications(
self,
stratifications: list[str] = [],
additional_stratifications: list[str] = [],
excluded_stratifications: list[str] = [],
) -> tuple[str, ...]:
"""Resolves the stratifications required for the observation."""
self._warn_check_stratifications(additional_stratifications, excluded_stratifications)
stratifications = list(
set(
self._results_context.default_stratifications
+ stratifications
+ additional_stratifications
)
- set(excluded_stratifications)
)
# Makes sure measure identifiers have fields in the same relative order.
return tuple(sorted(stratifications))
def _prepare_population(
self,
event: Event,
observations: list[Observation],
stratifications: list[Stratification],
) -> pd.DataFrame:
"""Prepares the population for results gathering."""
required_attributes = self._results_context.get_required_attributes(
observations, stratifications
)
attributes_to_get = [
attribute
for attribute in required_attributes
if attribute
not in ["current_time", "event_step_size", "event_time"]
+ list(event.user_data.keys())
]
if attributes_to_get:
# FIXME: (Inefficiency) In the event every single observation has some identical
# query string (e.g. 'is_alive == True'), we still calculate all attributes for
# the entire population and then apply the query downstream.
population = self.population_view.get(
event.index,
attributes_to_get,
include_untracked=any(
obs.population_filter.include_untracked for obs in observations
),
)
else:
population = pd.DataFrame(index=event.index)
if "current_time" in required_attributes:
population["current_time"] = self.clock()
if "event_step_size" in required_attributes:
population["event_step_size"] = event.step_size
if "event_time" in required_attributes:
population["event_time"] = self.clock() + event.step_size # type: ignore [operator]
for key, val in event.user_data.items():
if key in required_attributes:
population[key] = val
for stratification in stratifications:
new_column = get_mapped_col_name(stratification.name)
if new_column in population.columns:
raise ValueError(
f"Stratification column '{new_column}' already exists in the state table or "
"as a pipeline which is a required name for stratifying results - choose a "
"different name."
)
population[new_column] = stratification.stratify(population)
return population
def _warn_check_stratifications(
self, additional_stratifications: list[str], excluded_stratifications: list[str]
) -> None:
"""Checks additional and excluded stratifications if they'd not affect
stratifications (i.e., would be NOP), and emit warning."""
nop_additional = [
s
for s in additional_stratifications
if s in self._results_context.default_stratifications
]
if len(nop_additional):
self.logger.warning(
f"Specified additional stratifications are already included by default: {nop_additional}",
)
nop_exclude = [
s
for s in excluded_stratifications
if s not in self._results_context.default_stratifications
]
if len(nop_exclude):
self.logger.warning(
f"Specified excluded stratifications are already not included by default: {nop_exclude}",
)