Source code for vivarium.engine.framework.results.stratification

"""
===============
Stratifications
===============

"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import pandas as pd
from pandas.api.types import CategoricalDtype

from vivarium.engine.types import ScalarMapper, VectorMapper

# Suffix appended (after an underscore) to a column name to form the name of the
# column that holds mapped stratification values. ``get_mapped_col_name`` adds
# it and ``get_original_col_name`` strips it.
STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values"

# TODO: Parameterizing pandas objects fails below python 3.12


@dataclass
class Stratification:
    """Class for stratifying observed quantities by specified characteristics.

    Each Stratification represents a set of mutually exclusive and collectively
    exhaustive categories into which simulants can be assigned.

    This class includes a :meth:`stratify <stratify>` method that produces an
    output column by calling the mapper on the population's
    ``requires_attributes`` columns.
    """

    name: str
    """Name of the stratification."""
    requires_attributes: list[str]
    """The population attributes needed as input for the `mapper`."""
    categories: list[str]
    """Exhaustive list of all possible stratification values."""
    excluded_categories: list[str]
    """List of possible stratification values to exclude from results processing.
    The `mapper` may produce these values without raising an error, but
    :meth:`stratify` converts them to NaN as long as they do not also appear in
    `categories`."""
    mapper: VectorMapper | ScalarMapper | None
    """A callable that maps the population attributes specified by the
    `requires_attributes` argument to the stratification categories. It can
    either map the entire population or an individual simulant. If None, the
    single required attribute column is used as-is. A simulation will fail if
    the `mapper` ever produces an invalid value."""
    is_vectorized: bool = False
    """True if the `mapper` function will map the entire population, and False
    if it will only map a single simulant."""

    def __str__(self) -> str:
        return (
            f"Stratification '{self.name}' with required attributes {self.requires_attributes}, "
            f"categories {self.categories}, and mapper {getattr(self.mapper, '__name__', repr(self.mapper))}"
        )

    def __post_init__(self) -> None:
        """Assigns a default `mapper` if none was provided and checks for
        non-empty `categories` and `requires_attributes`.

        Raises
        ------
        ValueError
            If no mapper is provided and the number of required attributes is not 1.
        ValueError
            If the categories argument is empty.
        ValueError
            If the requires_attributes argument is empty.
        """
        # The mapper is resolved first, so an empty `requires_attributes` with no
        # mapper reports the mapper/attribute-count error rather than the
        # emptiness error below.
        self.vectorized_mapper = self._get_vectorized_mapper(self.mapper, self.is_vectorized)
        if not self.categories:
            raise ValueError("The categories argument must be non-empty.")
        if not self.requires_attributes:
            raise ValueError("The requires_attributes argument must be non-empty.")

    def stratify(self, population: pd.DataFrame) -> pd.Series[CategoricalDtype]:
        """Applies the `mapper` to the population's `requires_attributes` columns.

        This creates a new Series to be added to the population. Any mapped value
        that is listed in `excluded_categories` but not in `categories` is
        converted to NaN in the new column (to be dropped later at the
        observation level).

        Parameters
        ----------
        population
            A DataFrame containing the data to be stratified.

        Returns
        -------
            An ordered categorical Series (categories = `categories`) containing
            the mapped values to be used for stratifying.

        Raises
        ------
        ValueError
            If the mapper returns any values not in `categories` or
            `excluded_categories`, or returns any NaN values.
        """
        mapped_column = self.vectorized_mapper(population[self.requires_attributes])
        allowed_values = set(self.categories + self.excluded_categories)
        # Distinct NaN objects do not compare equal, so a set may hold several of
        # them. Drop them all here and add back a single representative below.
        unknown_categories = {
            value
            for value in set(mapped_column) - allowed_values
            if not pd.isna(value)
        }
        missing = mapped_column.isna()
        if missing.any():
            unknown_categories.add(mapped_column[missing].iat[0])
        if unknown_categories:
            raise ValueError(f"Invalid values mapped to {self.name}: {unknown_categories}")

        # Convert the dtype to the allowed categories. Values not in
        # `categories` (i.e. excluded ones) become NaN here.
        return mapped_column.astype(
            CategoricalDtype(categories=self.categories, ordered=True)
        )

    def _get_vectorized_mapper(
        self,
        user_provided_mapper: VectorMapper | ScalarMapper | None,
        is_vectorized: bool,
    ) -> VectorMapper:
        """Chooses a VectorMapper based on the provided callable mapper.

        Parameters
        ----------
        user_provided_mapper
            The mapper supplied by the user, or None to use the default mapper.
        is_vectorized
            Whether ``user_provided_mapper`` maps an entire DataFrame at once
            (True) or a single row at a time (False).

        Returns
        -------
            A callable taking the DataFrame of required attributes and returning
            the mapped values.

        Raises
        ------
        ValueError
            If no mapper is provided and there is not exactly one required attribute.
        """
        if user_provided_mapper is None:
            if len(self.requires_attributes) != 1:
                raise ValueError(
                    f"No mapper but {len(self.requires_attributes)} required attributes are "
                    f"provided for stratification {self.name}. The list of required attributes "
                    "must be of length 1 if no mapper is provided."
                )
            return self._default_mapper
        if is_vectorized:
            return user_provided_mapper  # type: ignore [return-value]
        # Apply the scalar mapper row by row. ``result_type="reduce"`` guarantees
        # a Series even for an empty population. Without it, pandas returns a
        # copy of the empty DataFrame whenever the mapper raises on its all-NaN
        # probe row, and ``stratify`` would then treat the column names as
        # mapped values.
        return lambda population: population.apply(
            user_provided_mapper, axis=1, result_type="reduce"
        )

    @staticmethod
    def _default_mapper(pop: pd.DataFrame) -> pd.Series[Any]:
        """Squeezes a DataFrame to a Series.

        Parameters
        ----------
        pop
            The data to be stratified.

        Returns
        -------
            The squeezed data to be stratified.

        Notes
        -----
        The input DataFrame is guaranteed to have a single column.
        """
        squeezed_pop: pd.Series[Any] = pop.squeeze(axis=1)
        return squeezed_pop
def get_mapped_col_name(col_name: str) -> str:
    """Build the name of the column that holds the mapped values for ``col_name``.

    The result is ``col_name`` followed by an underscore and
    ``STRATIFICATION_COLUMN_SUFFIX``.
    """
    suffix = f"_{STRATIFICATION_COLUMN_SUFFIX}"
    return f"{col_name}{suffix}"
def get_original_col_name(col_name: str) -> str:
    """Recover the original column name from a mapped-values column name.

    If ``col_name`` ends with an underscore followed by
    ``STRATIFICATION_COLUMN_SUFFIX``, that trailing part is removed; otherwise
    ``col_name`` is returned unchanged.
    """
    suffix = f"_{STRATIFICATION_COLUMN_SUFFIX}"
    if not col_name.endswith(suffix):
        return col_name
    return col_name[: len(col_name) - len(suffix)]