"""
==========================
Vivarium Testing Utilities
==========================
Utility functions and classes to make testing ``vivarium`` components easier.
"""
from __future__ import annotations
from collections.abc import Callable, Sequence
from datetime import datetime, timedelta
from itertools import product
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from vivarium.engine import Component
from vivarium.engine.framework.engine import Builder
from vivarium.engine.framework.event import Event
from vivarium.engine.framework.population import SimulantData
from vivarium.engine.framework.randomness.index_map import IndexMap
from vivarium.engine.framework.randomness.stream import RandomnessStream
from vivarium.engine.types import ClockStepSize, ClockTime
class NonCRNTestPopulation(Component):
    """Minimal population component for tests.

    Creates the ``age``, ``sex``, ``location``, ``is_alive``,
    ``entrance_time`` and ``exit_time`` columns from a single randomness
    stream and ages living simulants on every time step.
    """

    CONFIGURATION_DEFAULTS = {
        "population": {
            "initialization_age_min": 0,
            "initialization_age_max": 100,
            "untracking_age": None,
        },
    }

    def setup(self, builder: Builder) -> None:
        """Store the configuration, get the randomness stream and register the initializer."""
        self.config = builder.configuration
        self.randomness = builder.randomness.get_stream(
            "population_age_fuzz", initializes_crn_attributes=True
        )
        created_columns = ["age", "sex", "location", "is_alive", "entrance_time", "exit_time"]
        builder.population.register_initializer(
            initializer=self.initialize_population,
            columns=created_columns,
            required_resources=[self.randomness],
        )

    def initialize_population(self, pop_data: SimulantData) -> None:
        """Build the state table for the new simulants and hand it to the population view.

        The age bounds come from ``pop_data.user_data`` (keys ``age_start`` and
        ``age_end``) and fall back to the ``population`` configuration block.
        """
        user_data = pop_data.user_data
        population_config = self.config.population
        youngest = user_data.get("age_start", population_config.initialization_age_min)
        oldest = user_data.get("age_end", population_config.initialization_age_max)

        new_simulants = _non_crn_build_population(
            pop_data.index,
            youngest,
            oldest,
            self.config.input_data.location,
            pop_data.creation_time,
            pop_data.creation_window,
            self.randomness,
        )
        self.population_view.initialize(new_simulants)

    def on_time_step(self, event: Event) -> None:
        """Advance the age of every living simulant by the length of the step, in years."""
        alive = self.population_view.get_filtered_index(
            event.index, query="is_alive == True"
        )
        step = event.step_size
        # This component won't work if the step size is an int: there is no
        # way to express it in years, so ages are left untouched.
        if isinstance(step, int):
            return

        years_elapsed = step / pd.Timedelta(days=365)
        self.population_view.update(
            "age",
            lambda age: age.loc[alive] + years_elapsed,
        )
class TestPopulation(NonCRNTestPopulation):
    """Test population that also hands its simulants to ``register_simulants``.

    Ages are drawn from a dedicated ``age_initialization`` stream. The
    resulting ``entrance_time`` and ``age`` columns are passed to
    ``builder.randomness.register_simulants`` before the remaining columns
    are generated.
    """

    def setup(self, builder: Builder) -> None:
        """Store the configuration, get both randomness streams and register the initializer."""
        self.config = builder.configuration
        self.randomness = builder.randomness.get_stream(
            "population_age_fuzz", initializes_crn_attributes=True
        )
        self.age_randomness = builder.randomness.get_stream(
            "age_initialization", initializes_crn_attributes=True
        )
        self.register = builder.randomness.register_simulants
        builder.population.register_initializer(
            initializer=self.initialize_population,
            columns=["age", "sex", "location", "is_alive", "entrance_time", "exit_time"],
            required_resources=[self.randomness, self.age_randomness],
        )

    def initialize_population(self, pop_data: SimulantData) -> None:
        """Create, register and initialize the state table for the new simulants.

        Parameters
        ----------
        pop_data
            Creation metadata. ``user_data`` may carry ``age_start`` and
            ``age_end``; otherwise the ``population`` configuration block
            provides the bounds.

        Raises
        ------
        TypeError
            If ``age_start == age_end`` and ``pop_data.creation_window`` is an
            ``int``, because an integer window cannot be converted to years.
        """
        age_start = pop_data.user_data.get(
            "age_start", self.config.population.initialization_age_min
        )
        age_end = pop_data.user_data.get(
            "age_end", self.config.population.initialization_age_max
        )
        age_draw = self.age_randomness.get_draw(pop_data.index)

        if age_start == age_end:
            # A single starting age: spread simulants over the creation window
            # instead, which requires a time-like (non-integer) window.
            if isinstance(pop_data.creation_window, int):
                raise TypeError(
                    "TestPopulation cannot initialize ages when age_start equals "
                    "age_end and the creation window is an int; use a "
                    "Timedelta-based creation window."
                )
            age = age_draw * (pop_data.creation_window / pd.Timedelta(days=365)) + age_start
        else:
            age = age_draw * (age_end - age_start) + age_start

        core_population = pd.DataFrame(
            {"entrance_time": pop_data.creation_time, "age": age.values}, index=pop_data.index
        )
        self.register(core_population)

        if "location" in self.config.input_data.keys():
            location = self.config.input_data.location
        else:
            # No configured location: pick one per simulant.
            location = self.randomness.choice(
                pop_data.index, ["USA", "Canada", "Mexico"], additional_key="location_choice"
            )

        population = _build_population(core_population, location, self.randomness)
        self.population_view.initialize(population)
def _build_population(
    core_population: pd.DataFrame, location: str, randomness_stream: RandomnessStream
) -> pd.DataFrame:
    """Extend a core ``age``/``entrance_time`` table into a full test state table.

    Parameters
    ----------
    core_population
        Frame providing the ``age`` and ``entrance_time`` columns; its index
        is reused for the result.
    location
        Value stored in the ``location`` column.
    randomness_stream
        Stream used to choose each simulant's ``sex``.

    Returns
    -------
        A frame with the columns ``age``, ``entrance_time``, ``sex``,
        ``is_alive``, ``location`` and ``exit_time``, in that order.
    """
    simulants = core_population.index
    sex = randomness_stream.choice(simulants, ["Male", "Female"], additional_key="sex_choice")
    everyone_alive = pd.Series(True, index=simulants)

    columns = {
        "age": core_population["age"],
        "entrance_time": core_population["entrance_time"],
        "sex": sex,
        "is_alive": everyone_alive,
        "location": location,
        "exit_time": pd.NaT,
    }
    return pd.DataFrame(columns, index=simulants)
def _non_crn_build_population(
    index: pd.Index[int],
    age_start: float,
    age_end: float,
    location: str,
    creation_time: ClockTime,
    creation_window: ClockStepSize,
    randomness_stream: RandomnessStream,
) -> pd.DataFrame:
    """Build a complete test state table from a single randomness stream.

    Parameters
    ----------
    index
        Index of the simulants to create.
    age_start
        Lower bound of the initial ages.
    age_end
        Upper bound of the initial ages. When equal to ``age_start``, ages are
        spread over ``creation_window`` (expressed in 365-day years) instead.
    location
        Value stored in the ``location`` column.
    creation_time
        Value stored in the ``entrance_time`` column.
    creation_window
        Length of the window the simulants are created in.
    randomness_stream
        Stream providing the age draw and the ``sex`` choice.

    Returns
    -------
        A frame with the columns ``age``, ``sex``, ``is_alive``, ``location``,
        ``entrance_time`` and ``exit_time``, in that order.

    Raises
    ------
    TypeError
        If ``age_start == age_end`` and ``creation_window`` is an ``int``,
        because an integer window cannot be converted to years.
    """
    if age_start == age_end:
        # Previously this combination left ``age`` unbound and failed later
        # with an opaque UnboundLocalError; fail early and say why.
        if isinstance(creation_window, int):
            raise TypeError(
                "Cannot initialize ages when age_start equals age_end and the "
                "creation window is an int; use a Timedelta-based creation window."
            )
        age_span = creation_window / pd.Timedelta(days=365)
    else:
        age_span = age_end - age_start

    age = randomness_stream.get_draw(index) * age_span + age_start

    population = pd.DataFrame(
        {
            "age": age,
            "sex": randomness_stream.choice(
                index, ["Male", "Female"], additional_key="sex_choice"
            ),
            "is_alive": pd.Series(True, index=index),
            "location": location,
            "entrance_time": creation_time,
            "exit_time": pd.NaT,
        },
        index=index,
    )
    return population
def build_table(
    value: Any,
    parameter_columns: dict[str, Sequence[int]] = {
        "age": (0, 125),
        "year": (1990, 2020),
    },
    key_columns: dict[str, Sequence[Any]] = {"sex": ("Female", "Male")},
    value_columns: list[str] = ["value"],
) -> pd.DataFrame:
    """Build a lookup-table style data frame for tests.

    Parameters
    ----------
    value
        Value(s) to put in the value columns. A list supplies one entry per
        value column; anything else is repeated for every value column. Each
        entry may be ``None`` (a fresh ``np.random.random()`` draw per row), a
        callable (called with the row's tuple of parameter and key values), or
        a constant.
    parameter_columns
        Maps each parameter (continuous) column name to a ``(start, end)``
        pair. Every integer ``v`` in ``range(start, end)`` becomes one bin
        with ``<name>_start == v`` and ``<name>_end == v + 1``.
    key_columns
        Maps each key (categorical) column name to its categories.
    value_columns
        Names of the value columns in the returned table.

    Returns
    -------
        A data frame with one row per element of the cartesian product of the
        parameter bins and the key categories.

    Raises
    ------
    ValueError
        If the number of values does not match the number of value columns.
    """
    values = value if isinstance(value, list) else [value] * len(value_columns)
    if len(values) != len(value_columns):
        raise ValueError("Number of values must match number of value columns")

    n_parameters = len(parameter_columns)

    # Axes of the cartesian product: parameter bins first, then key categories.
    axes: dict[str, Any] = {
        name: list(range(bounds[0], bounds[1])) for name, bounds in parameter_columns.items()
    }
    axes.update(key_columns)

    def _resolve(spec: Any, combination: tuple[Any, ...]) -> Any:
        # Turn one value specification into the cell content for this row.
        if spec is None:
            return np.random.random()
        if callable(spec):
            return spec(combination)
        return spec

    rows = []
    for combination in product(*axes.values()):
        cell_values = [_resolve(spec, combination) for spec in values]
        # Expand every parameter value into its (start, end) pair, e.g. year, year + 1.
        bin_edges: list[Any] = []
        for lower in combination[:n_parameters]:
            bin_edges.extend((lower, lower + 1))
        categories = list(combination[n_parameters:])
        rows.append(bin_edges + categories + cell_values)

    column_names: list[str] = []
    for name in parameter_columns:
        column_names.extend((f"{name}_start", f"{name}_end"))
    column_names = column_names + list(key_columns) + value_columns

    return pd.DataFrame(rows, columns=column_names)
def get_randomness(
    key: str = "test",
    clock: Callable[[], pd.Timestamp | datetime | int] = lambda: pd.Timestamp(1990, 7, 2),
    seed: int = 12345,
    initializes_crn_attributes: bool = False,
    component: Component | None = None,
) -> RandomnessStream:
    """Create a stand-alone ``RandomnessStream`` for use in tests.

    Parameters
    ----------
    key
        First positional argument passed to ``RandomnessStream``.
    clock
        Zero-argument callable passed to the stream as its clock.
    seed
        Seed passed to the stream.
    initializes_crn_attributes
        Forwarded unchanged to the stream.
    component
        Component passed to the stream. When omitted, a throwaway component
        named ``"mock_component"`` is created.

    Returns
    -------
        A new stream backed by a fresh ``IndexMap``.
    """
    owner = component
    if owner is None:
        # Create a simple mock component for testing
        class _MockComponent(Component):
            @property
            def name(self) -> str:
                return "mock_component"

        owner = _MockComponent()

    stream_options = {
        "seed": seed,
        "index_map": IndexMap(),
        "component": owner,
        "initializes_crn_attributes": initializes_crn_attributes,
    }
    return RandomnessStream(key, clock, **stream_options)