"""
===============
Results Context
===============
"""
from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any
import pandas as pd
from pandas.core.groupby.generic import DataFrameGroupBy
from vivarium.engine.framework.event import Event
from vivarium.engine.framework.population import utilities as pop_utils
from vivarium.engine.framework.results.exceptions import ResultsConfigurationError
from vivarium.engine.framework.results.observation import Observation
from vivarium.engine.framework.results.stratification import (
Stratification,
get_mapped_col_name,
)
from vivarium.engine.types import ScalarMapper, VectorMapper
if TYPE_CHECKING:
from vivarium.engine.framework.engine import Builder
from vivarium.engine.framework.results.interface import PopulationFilter
[docs]
class ResultsContext:
    """Manager for organizing observations and their required stratifications.

    This context object is wholly contained by :class:`ResultsManager <vivarium.engine.framework.results.manager.ResultsManager>`.
    Stratifications and observations can be added to the context through the manager via the
    :meth:`add_stratification <vivarium.engine.framework.results.context.ResultsContext.add_stratification>` and
    :meth:`register_observation <vivarium.engine.framework.results.context.ResultsContext.register_observation>` methods, respectively.

    Attributes
    ----------
    default_stratifications
        List of column names to use for stratifying results.
    stratifications
        Dictionary of :class:`Stratification <vivarium.engine.framework.results.stratification.Stratification>`
        objects to be applied to results, keyed by the stratification name.
    excluded_categories
        Dictionary of possible per-metric stratification values to be excluded
        from results processing. It is looked up by stratification name when a
        stratification is added without explicit exclusions.
    observations
        Dictionary of :class:`Observation <vivarium.engine.framework.results.observation.Observation>`
        objects to be produced keyed by the observation name.
    grouped_observations
        Dictionary of observation details. It is of the format
        {lifecycle_state: {PopulationFilter: {stratifications: list[Observation]}}}.
        Allowable lifecycle_states are "time_step__prepare", "time_step",
        "time_step__cleanup", and "collect_metrics".
    logger
        Logger for the results context. Assigned during :meth:`setup`.
    get_tracked_query
        Object returned by ``builder.population.get_tracked_query()`` during
        :meth:`setup`. It is called with no arguments to obtain the query that is
        combined with a population filter's query whenever untracked simulants
        are not to be included.
    """

    def __init__(self) -> None:
        self.default_stratifications: list[str] = []
        self.stratifications: dict[str, Stratification] = {}
        self.excluded_categories: dict[str, list[str]] = {}
        self.observations: dict[str, Observation] = {}
        # Three nested levels: lifecycle state -> population filter ->
        # stratification names (or None) -> observations. Every level is
        # created on first access.
        self.grouped_observations: defaultdict[
            str,
            defaultdict[
                PopulationFilter,
                defaultdict[tuple[str, ...] | None, list[Observation]],
            ],
        ] = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

    @property
    def name(self) -> str:
        """The name of this context; also used to request its logger."""
        return "results_context"

    def setup(self, builder: Builder) -> None:
        """Sets up the results context.

        This method is called by the :class:`ResultsManager <vivarium.engine.framework.results.manager.ResultsManager>`
        during the setup phase of that object.

        Parameters
        ----------
        builder
            The simulation builder, used here to obtain the logger, the
            configured excluded categories, and the tracked-query getter.
        """
        self.logger = builder.logging.get_logger(self.name)
        self.excluded_categories = (
            builder.configuration.stratification.excluded_categories.to_dict()
        )
        self.get_tracked_query = builder.population.get_tracked_query()

    # noinspection PyAttributeOutsideInit
    def set_default_stratifications(self, default_grouping_columns: list[str]) -> None:
        """Sets the default stratifications to be used by stratified observations.

        Parameters
        ----------
        default_grouping_columns
            List of stratifications to be used.

        Raises
        ------
        ResultsConfigurationError
            If the `self.default_stratifications` attribute has already been set.
        """
        if self.default_stratifications:
            raise ResultsConfigurationError(
                "Multiple calls are being made to set default grouping columns "
                "for results production."
            )
        self.default_stratifications = default_grouping_columns

    def set_stratifications(self) -> None:
        """Sets stratifications on all Observers.

        Every observation registered with stratification names is given the
        tuple of matching registered
        :class:`Stratification <vivarium.engine.framework.results.stratification.Stratification>`
        objects. Observations registered with ``None`` are left untouched.

        Logs an info-level message if any registered stratifications are not
        being used by any observation.

        Raises
        ------
        ValueError
            If any observation uses a stratification name that has not been
            registered.
        """
        used_stratifications: set[str] = set()
        for state_observations in self.grouped_observations.values():
            for pop_filter_observations in state_observations.values():
                for stratification_names, observations in pop_filter_observations.items():
                    if stratification_names is None:
                        continue
                    used_stratifications |= set(stratification_names)
                    for observation in observations:
                        # Unregistered names are skipped here and reported
                        # collectively by the ValueError below.
                        observation.stratifications = tuple(
                            self.stratifications[name]
                            for name in stratification_names
                            if name in self.stratifications
                        )

        registered_stratifications = set(self.stratifications)
        if unused_stratifications := registered_stratifications - used_stratifications:
            self.logger.info(
                "The following stratifications are registered but not used by any "
                f"observers: \n{sorted(unused_stratifications)}"
            )
        if missing_stratifications := used_stratifications - registered_stratifications:
            raise ValueError(
                "The following stratifications are used by observers but not registered: "
                f"\n{sorted(missing_stratifications)}"
            )

    def add_stratification(
        self,
        name: str,
        requires_attributes: list[str],
        categories: list[str],
        excluded_categories: list[str] | None,
        mapper: VectorMapper | ScalarMapper | None,
        is_vectorized: bool,
    ) -> None:
        """Adds a stratification to the results context.

        Parameters
        ----------
        name
            Name of the stratification.
        requires_attributes
            The population attributes needed as input for the `mapper`.
        categories
            Exhaustive list of all possible stratification values. This list is
            never modified.
        excluded_categories
            List of possible stratification values to exclude from results processing.
            If None (the default), will use exclusions as defined in the configuration.
        mapper
            A callable that maps the population attributes specified by
            `requires_attributes` to the stratification categories. It can either map the entire
            population or an individual simulant. A simulation will fail if the `mapper`
            ever produces an invalid value.
        is_vectorized
            True if the `mapper` function will map the entire population, and False
            if it will only map a single simulant.

        Raises
        ------
        ValueError
            - If the stratification `name` is already used.
            - If there are duplicate `categories`.
            - If any `excluded_categories` are not in `categories`.
        """
        if name in self.stratifications:
            raise ValueError(f"Stratification name '{name}' is already used.")

        # Collect every repeated occurrence (all but the first of each value)
        # into a separate list so the caller's `categories` is left intact.
        seen_categories: set[str] = set()
        duplicate_categories: list[str] = []
        for category in categories:
            if category in seen_categories:
                duplicate_categories.append(category)
            else:
                seen_categories.add(category)
        if duplicate_categories:
            raise ValueError(
                f"Found duplicate categories in stratification '{name}': {duplicate_categories}."
            )

        # Handle excluded categories. If excluded_categories are explicitly
        # passed in, we use that instead of what is in the model spec.
        to_exclude = (
            excluded_categories
            if excluded_categories is not None
            else self.excluded_categories.get(name, [])
        )
        unknown_exclusions = set(to_exclude) - set(categories)
        if unknown_exclusions:
            raise ValueError(
                f"Excluded categories {unknown_exclusions} not found in categories "
                f"{categories} for stratification '{name}'."
            )
        if to_exclude:
            self.logger.debug(
                f"'{name}' has category exclusion requests: {to_exclude}\n"
                "Removing these from the allowable categories."
            )
            # Build a new list rather than editing the argument in place.
            categories = [category for category in categories if category not in to_exclude]

        self.stratifications[name] = Stratification(
            name=name,
            requires_attributes=requires_attributes,
            categories=categories,
            excluded_categories=to_exclude,
            mapper=mapper,
            is_vectorized=is_vectorized,
        )

    def register_observation(
        self,
        observation_type: type[Observation],
        name: str,
        population_filter: PopulationFilter,
        when: str,
        requires_attributes: list[str],
        stratifications: tuple[str, ...] | None,
        **kwargs: Any,
    ) -> Observation:
        """Adds an observation to the results context.

        Parameters
        ----------
        observation_type
            Specific class type of observation to register.
        name
            Name of the observation. It will also be the name of the output results file
            for this particular observation.
        population_filter
            A named tuple of population filtering details. The first item is a Pandas
            query string to filter the population down to the simulants who should be
            considered for the observation. The second item is a boolean indicating whether
            to include untracked simulants in the observation.
        when
            Name of the lifecycle state the observation should happen. Valid values are:
            "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
        requires_attributes
            The population attributes required by the observation. Passed to the
            observation's constructor.
        stratifications
            Names of the stratifications to group this observation by, or None.
            Observations are stored under this value so that observations sharing
            a population filter and stratifications are processed together.
        **kwargs
            Additional keyword arguments to be passed to the observation's constructor.

        Returns
        -------
            The instantiated Observation object.

        Raises
        ------
        ValueError
            If the observation `name` is already used.
        """
        if name in self.observations:
            raise ValueError(
                f"Observation name '{name}' is already used: {self.observations[name]}."
            )
        observation = observation_type(
            name=name,
            population_filter=population_filter,
            when=when,
            requires_attributes=requires_attributes,
            **kwargs,
        )
        self.observations[name] = observation
        self.grouped_observations[observation.when][observation.population_filter][
            stratifications
        ].append(observation)
        return observation

    def gather_results(
        self,
        population: pd.DataFrame,
        lifecycle_state: str,
        event_observations: list[Observation],
    ) -> Generator[
        tuple[pd.DataFrame, str, Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]],
        None,
        None,
    ]:
        """Generates and yields current results for all observations at this lifecycle state and event.

        Each set of results are stratified and grouped by all registered stratifications
        as well as filtered by their respective observation's pop_filter.

        Parameters
        ----------
        population
            The current population DataFrame.
        lifecycle_state
            The current lifecycle state.
        event_observations
            List of observations to be gathered for this specific event. Note that this
            excludes all observations whose `to_observe` method returns False.

        Yields
        ------
            A tuple containing each observation's newly observed results, the name of
            the observation, and the observation's results updater function.
            Observations whose filtered population has zero rows are silently
            skipped (nothing is yielded for them).

        Notes
        -----
        NOTE(review): earlier documentation listed a ``ValueError`` for a
        stratification's temporary column name already existing in the population
        DataFrame. Nothing in this method raises it directly; confirm whether a
        callee does.
        """
        # Optimization: We store all the producers by population_filter and stratifications
        # so that we only have to apply them once each time we compute results.
        for (
            population_filter,
            stratification_observations,
        ) in self.grouped_observations[lifecycle_state].items():
            event_pop_filter_observations = [
                observation
                for observations in stratification_observations.values()
                for observation in observations
                if observation in event_observations
            ]
            if not event_pop_filter_observations:
                # Skip the (potentially expensive) filtering entirely.
                continue

            filtered_population = self._filter_population(population, population_filter)
            if filtered_population.index.empty:
                continue

            for stratification_names, observations in stratification_observations.items():
                observations = [
                    obs for obs in observations if obs in event_pop_filter_observations
                ]
                if not observations:
                    continue

                pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool]
                pop = self._drop_na_stratifications(filtered_population, stratification_names)
                if pop.index.empty:
                    continue
                # None means the observation receives the ungrouped DataFrame;
                # an empty tuple still goes through `_get_groups`.
                if stratification_names is not None:
                    pop = self._get_groups(stratification_names, pop)

                for observation in observations:
                    results = observation.observe(pop, stratification_names)
                    yield (results, observation.name, observation.results_updater)

    def get_observations(self, event: Event) -> list[Observation]:
        """Gets all observations for a given event.

        Parameters
        ----------
        event
            The current Event.

        Returns
        -------
            A list of Observations for the given event. Only includes observations
            whose `to_observe` method returns True.
        """
        return [
            observation
            for stratification_observations in self.grouped_observations[event.name].values()
            for observations in stratification_observations.values()
            for observation in observations
            if observation.to_observe(event)
        ]

    def get_stratifications(self, observations: list[Observation]) -> list[Stratification]:
        """Gets all stratifications for a given set of observations.

        Parameters
        ----------
        observations
            The observations to gather stratifications from.

        Returns
        -------
            A list of Stratifications used by at least one of the observations.
            Each stratification name appears at most once.
        """
        # Keying by name de-duplicates stratifications shared across observations.
        return list(
            {
                stratification.name: stratification
                for observation in observations
                if observation.stratifications is not None
                for stratification in observation.stratifications
            }.values()
        )

    def get_required_attributes(
        self, observations: list[Observation], stratifications: list[Stratification]
    ) -> list[str]:
        """Gets all population attributes required for producing results for a given Event.

        Parameters
        ----------
        observations
            List of observations to be gathered for this specific event. Note that this
            excludes all observations whose `to_observe` method returns False.
        stratifications
            List of stratifications to be gathered for this specific event. This only
            includes stratifications which are needed by the observations which will be
            made during this `Event`.

        Returns
        -------
            All population attributes required for producing results for the given Event.
            The attributes are de-duplicated and returned in no particular order.
        """
        required_attributes = set()
        for observation in observations:
            required_attributes.update(set(observation.requires_attributes))
            # The tracked query is only applied (and so only needs its columns)
            # when untracked simulants are not included.
            required_attributes.update(
                pop_utils.extract_columns_from_query(self.get_tracked_query())
                if not observation.population_filter.include_untracked
                else set()
            )
            required_attributes.update(
                pop_utils.extract_columns_from_query(observation.population_filter.query)
            )
        for stratification in stratifications:
            required_attributes.update(stratification.requires_attributes)
        return list(required_attributes)

    def _filter_population(
        self, population: pd.DataFrame, population_filter: PopulationFilter
    ) -> pd.DataFrame:
        """Filter out simulants not to observe.

        Applies the population filter's query, combined with the tracked query
        unless untracked simulants are to be included. If the resulting query is
        empty, a copy of the full population is returned.
        """
        query = population_filter.query
        if not population_filter.include_untracked:
            # combine the tracking query with the population filter query
            query = pop_utils.combine_queries(query, self.get_tracked_query())
        return population.query(query) if query else population.copy()

    def _drop_na_stratifications(
        self, population: pd.DataFrame, stratification_names: tuple[str, ...] | None
    ) -> pd.DataFrame:
        """Drop simulants that have a missing value in any mapped stratification column.

        If `stratification_names` is None or empty, the population is returned
        unchanged.
        """
        if stratification_names:
            # Drop all rows in the mapped_stratification columns that have NaN values
            # (which only exist if the mapper returned an excluded category).
            population = population.dropna(
                subset=[
                    get_mapped_col_name(stratification)
                    for stratification in stratification_names
                ]
            )
        return population

    @staticmethod
    def _get_groups(
        stratifications: tuple[str, ...], filtered_pop: pd.DataFrame
    ) -> DataFrameGroupBy[tuple[str, ...] | str, bool]:
        """Groups the population by stratification.

        Notes
        -----
        Stratifications at this point can be an empty tuple.

        HACK: If there are no `stratifications` (i.e. it's an empty tuple), we
        create a single group of the entire `filtered_pop` index and assign
        it a name of "all". The alternative is to use the entire population
        instead of a groupby object, but then we would need to handle
        the different ways the aggregator can behave.
        """
        if stratifications:
            pop_groups = filtered_pop.groupby(
                [get_mapped_col_name(stratification) for stratification in stratifications],
                observed=False,
            )
        else:
            pop_groups = filtered_pop.groupby(lambda _: "all")
        return pop_groups  # type: ignore[return-value]