Source code for vivarium.engine.framework.randomness.manager

"""
==================
Randomness Manager
==================

"""

from __future__ import annotations

from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Literal

import pandas as pd

from vivarium.engine.framework.lifecycle import lifecycle_states
from vivarium.engine.framework.randomness.exceptions import RandomnessError
from vivarium.engine.framework.randomness.index_map import IndexMap
from vivarium.engine.framework.randomness.stream import RandomnessStream, get_hash
from vivarium.engine.manager import Manager
from vivarium.engine.types import ClockTime

if TYPE_CHECKING:
    from vivarium.engine.component import Component
    from vivarium.engine.framework.engine import Builder
    from vivarium.engine.framework.resource import Resource


class RandomnessManager(Manager):
    """Access point for common random number generation.

    The manager holds the simulation-wide seed string, an :class:`IndexMap`
    built over the configured key columns, and a registry of every
    :class:`RandomnessStream` it has created, keyed by decision point.
    """

    CONFIGURATION_DEFAULTS = {
        "randomness": {
            "map_size": 1_000_000,
            "key_columns": [],
            "random_seed": 0,
            "additional_seed": None,
            "rate_conversion_type": "linear",
        }
    }

    def __init__(self) -> None:
        # Seed string "<random_seed>_<additional_seed>", populated in ``setup``.
        self._seed: str = ""
        # Simulation clock accessor; ``None`` until ``setup`` runs.
        self._clock_: Callable[[], ClockTime] | None = None
        # State-table columns used to identify simulants for CRN.
        self._key_columns: list[str] = []
        # Index map over the key columns; ``None`` until ``setup`` runs.
        self._key_mapping_: IndexMap | None = None
        # Every stream created so far, keyed by its decision point.
        self._decision_points: dict[str, RandomnessStream] = {}
        self._rate_conversion_type: Literal["linear", "exponential"] = "linear"

    @property
    def name(self) -> str:
        return "randomness_manager"

    @property
    def _clock(self) -> Callable[[], ClockTime]:
        """The simulation clock accessor stored during :meth:`setup`.

        Raises
        ------
        RandomnessError
            If accessed before :meth:`setup` has stored the clock.
        """
        if self._clock_ is None:
            raise RandomnessError("RandomnessManager clock was invoked before being set.")
        return self._clock_

    @property
    def _key_mapping(self) -> IndexMap:
        """The key-column index map built during :meth:`setup`.

        Raises
        ------
        RandomnessError
            If accessed before :meth:`setup` has built the map.
        """
        if self._key_mapping_ is None:
            raise RandomnessError(
                "RandomnessManager key_mapping was invoked before being set."
            )
        return self._key_mapping_

    def setup(self, builder: Builder) -> None:
        """Configure the manager from the simulation configuration.

        Builds the seed from ``randomness.random_seed`` and either
        ``randomness.additional_seed`` or, when that is ``None``,
        ``input_data.input_draw_number``. Creates the key-column index map
        with a size of at least ``randomness.map_size`` and at least ten times
        ``population.population_size``, stores the builder services used
        later, and registers lifecycle constraints on :meth:`get_seed`,
        :meth:`get_randomness_stream` and :meth:`register_simulants`.

        Parameters
        ----------
        builder
            Access point for simulation configuration and framework services.
        """
        randomness_config = builder.configuration.randomness

        if randomness_config.additional_seed is not None:
            additional_seed = randomness_config.additional_seed
        else:
            additional_seed = builder.configuration.input_data.input_draw_number
        self._seed = f"{randomness_config.random_seed}_{additional_seed}"
        self._clock_ = builder.time.clock()
        self._key_columns = randomness_config.key_columns

        # Never let the map be smaller than 10x the initial population.
        map_size = randomness_config.map_size
        pop_size = builder.configuration.population.population_size
        map_size = max(map_size, 10 * pop_size)
        self._key_mapping_ = IndexMap(self._key_columns, map_size)

        # Builder services below are only available after ``setup`` has run.
        self._get_current_component = builder.components.get_current_component
        # NOTE(review): this configured value is stored but not read anywhere
        # in this class; ``get_randomness_stream`` uses its own
        # ``rate_conversion_type`` argument (default "linear"). Confirm whether
        # the configured value is meant to be the default for new streams.
        self._rate_conversion_type = randomness_config.rate_conversion_type
        self._add_constraint = builder.lifecycle.add_constraint
        self._add_resource = builder.resources.add_resource

        self._add_constraint(self.get_seed, restrict_during=[lifecycle_states.INITIALIZATION])
        self._add_constraint(
            self.get_randomness_stream, allow_during=[lifecycle_states.SETUP]
        )
        self._add_constraint(
            self.register_simulants,
            restrict_during=[
                lifecycle_states.INITIALIZATION,
                lifecycle_states.SETUP,
                lifecycle_states.POST_SETUP,
                lifecycle_states.SIMULATION_END,
                lifecycle_states.REPORT,
            ],
        )

    def get_randomness_stream(
        self,
        decision_point: str,
        initializes_crn_attributes: bool = False,
        rate_conversion_type: Literal["linear", "exponential"] = "linear",
    ) -> RandomnessStream:
        """Provides a new source of random numbers for the given decision point.

        The new stream is attributed to the component currently being set up
        and registered as a resource. Its ``get_draw``,
        ``filter_for_probability``, ``filter_for_rate`` and ``choice`` methods
        are constrained so they cannot be called during the initialization,
        setup or post-setup lifecycle states.

        Parameters
        ----------
        decision_point
            A unique identifier for a stream of random numbers. Typically
            represents a decision that needs to be made each time step like
            'moves_left' or 'gets_disease'.
        initializes_crn_attributes
            A flag indicating whether this stream is used to generate key
            initialization information that will be used to identify simulants
            in the Common Random Number framework. These streams cannot be
            copied and should only be used to generate the state table columns
            specified in ``builder.configuration.randomness.key_columns``.
        rate_conversion_type
            The type of conversion to use. Default is "linear" for a simple
            multiplication of rate and time_scaling_factor. The other option is
            "exponential".

        Returns
        -------
            An entry point into the Common Random Number framework. The stream
            provides vectorized access to random numbers and a few other
            utilities.

        Raises
        ------
        RandomnessError
            If another location in the simulation has already created a
            randomness stream with the same identifier.
        """
        stream = self._get_randomness_stream(
            decision_point,
            self._get_current_component(),
            initializes_crn_attributes,
            rate_conversion_type,
            # Streams that generate the key columns cannot require them.
            self._key_columns if not initializes_crn_attributes else [],
        )
        self._add_resource(stream)

        restricted_states = (
            lifecycle_states.INITIALIZATION,
            lifecycle_states.SETUP,
            lifecycle_states.POST_SETUP,
        )
        for draw_method in (
            stream.get_draw,
            stream.filter_for_probability,
            stream.filter_for_rate,
            stream.choice,
        ):
            # A fresh list per call, matching one independent list per constraint.
            self._add_constraint(draw_method, restrict_during=list(restricted_states))
        return stream

    def _get_randomness_stream(
        self,
        decision_point: str,
        component: Component,
        initializes_crn_attributes: bool = False,
        rate_conversion_type: Literal["linear", "exponential"] = "linear",
        required_resources: Iterable[str | Resource] = (),
    ) -> RandomnessStream:
        """Create a stream and record it under ``decision_point``.

        Parameters
        ----------
        decision_point
            Unique identifier for the stream.
        component
            The component the stream is attributed to.
        initializes_crn_attributes
            Whether the stream generates CRN key-column values.
        rate_conversion_type
            Rate conversion passed through to the stream.
        required_resources
            Resources the stream depends on.

        Returns
        -------
            The newly created stream.

        Raises
        ------
        RandomnessError
            If a stream already exists for ``decision_point``.
        """
        if decision_point in self._decision_points:
            raise RandomnessError(
                f"Two separate places are attempting to create "
                f"the same randomness stream for {decision_point}"
            )
        stream = RandomnessStream(
            key=decision_point,
            clock=self._clock,
            seed=self._seed,
            index_map=self._key_mapping,
            component=component,
            initializes_crn_attributes=initializes_crn_attributes,
            rate_conversion_type=rate_conversion_type,
            required_resources=required_resources,
        )
        self._decision_points[decision_point] = stream
        return stream

    def get_seed(self, decision_point: str) -> int:
        """Gets a randomly generated seed for use with external randomness tools.

        The seed is a hash of the decision point, the current clock time and
        the manager's seed string, so it changes from one time step to the next.

        Parameters
        ----------
        decision_point
            A unique identifier for a stream of random numbers. Typically
            represents a decision that needs to be made each time step like
            'moves_left' or 'gets_disease'.

        Returns
        -------
            A seed for a random number generation that is linked to Vivarium's
            Common Random Number framework.
        """
        return get_hash("_".join([decision_point, str(self._clock()), str(self._seed)]))

    def register_simulants(self, simulants: pd.DataFrame) -> None:
        """Adds new simulants to the randomness mapping.

        Parameters
        ----------
        simulants
            A table with state data representing the new simulants. Each
            simulant should pass through this function exactly once.

        Raises
        ------
        RandomnessError
            If the provided table does not contain all key columns specified in
            the configuration. The message lists the missing columns.
        """
        missing_columns = [
            column for column in self._key_columns if column not in simulants.columns
        ]
        if missing_columns:
            raise RandomnessError(
                "The simulants dataframe does not have all specified key_columns. "
                f"Missing: {missing_columns}."
            )
        self._key_mapping.update(simulants.loc[:, self._key_columns], self._clock())

    def __str__(self) -> str:
        return "RandomnessManager()"

    def __repr__(self) -> str:
        return f"RandomnessManager(seed={self._seed}, key_columns={self._key_columns})"