"""
==============================
Branch and Keyspace Management
==============================
Tools for managing the parameter space of a parallel run.
"""
from __future__ import annotations
from collections.abc import Iterator
from itertools import product
from pathlib import Path
from typing import Any
import numpy as np
import yaml
from vivarium.engine.framework.utilities import collapse_nested_dict
from vivarium.cluster_tools.psimulate.model_specification import FULL_ARTIFACT_PATH_KEY
# NOTE(review): this constant is not referenced anywhere in the visible portion
# of this module; confirm it is still used elsewhere before changing it.
NUMBER_OF_DRAWS = 250
[docs]
class Keyspace:
    """A representation of a collection of simulation configurations.

    A keyspace couples a list of branch configurations with a mapping from
    parameter names to the values they take. The ``"input_draw"`` and
    ``"random_seed"`` entries of that mapping, together with the branches,
    define the individual simulation runs: iterating over a keyspace yields
    one ``(input_draw, random_seed, branch)`` tuple per run.
    """

    def __init__(self, branches: list[dict[str, Any]], keyspace: dict[str, Any]):
        self.branches = branches
        self._keyspace = keyspace

    @classmethod
    def from_branch_configuration(cls, branch_configuration_file: str | Path) -> Keyspace:
        """Build a keyspace from a branch configuration file.

        Explicit ``input_draws`` / ``random_seeds`` lists from the file take
        precedence; when they are absent or empty the values are generated
        from the corresponding counts.

        Parameters
        ----------
        branch_configuration_file
            Absolute path to the branch configuration file.

        Returns
        -------
            The keyspace described by the branch configuration file.
        """
        (
            branches,
            input_draw_count,
            random_seed_count,
            input_draws,
            random_seeds,
        ) = load_branch_configuration(Path(branch_configuration_file))
        keyspace = calculate_keyspace(branches)
        keyspace["input_draw"] = (
            input_draws if input_draws else calculate_input_draws(input_draw_count)
        )
        keyspace["random_seed"] = (
            random_seeds if random_seeds else calculate_random_seeds(random_seed_count)
        )
        return cls(branches, keyspace)

    @classmethod
    def from_previous_run(cls, keyspace_path: Path, branches_path: Path) -> Keyspace:
        """Rebuild a keyspace from the files written by :meth:`persist`.

        Parameters
        ----------
        keyspace_path
            Path to the YAML file holding the keyspace mapping.
        branches_path
            Path to the YAML file holding the list of branches.

        Returns
        -------
            The keyspace stored in the two files.
        """
        # NOTE(review): yaml.full_load is more permissive than yaml.safe_load.
        # These files are assumed to be trusted output of persist() -- confirm.
        keyspace = yaml.full_load(keyspace_path.read_text())
        branches = yaml.full_load(branches_path.read_text())
        return cls(branches, keyspace)

    @classmethod
    def for_load_test(cls, num_workers: int) -> Keyspace:
        """Create a keyspace for load testing.

        Parameters
        ----------
        num_workers
            The number of workers (and thus jobs) to create.

        Returns
        -------
            A Keyspace with a single input draw, ``num_workers`` unique random
            seeds and a single empty branch configuration.
        """
        return cls(
            branches=[{}],
            keyspace={
                "input_draw": [0],
                "random_seed": list(range(num_workers)),
            },
        )

    @classmethod
    def from_entry_point_args(
        cls,
        input_branch_configuration_path: Path | None,
        keyspace_path: Path,
        branches_path: Path,
        extras: dict[str, Any],
    ) -> Keyspace:
        """Select the appropriate constructor from command line arguments.

        Parameters
        ----------
        input_branch_configuration_path
            Path to a branch configuration file, or ``None`` if none was given.
        keyspace_path
            Path where a previous run stored its keyspace mapping.
        branches_path
            Path where a previous run stored its branches.
        extras
            Additional options. ``"num_draws"`` and ``"num_seeds"`` (both
            defaulting to 0) extend a keyspace restored from a previous run.

        Returns
        -------
            A keyspace built from the branch configuration if one was given,
            otherwise the (optionally extended) keyspace of the previous run if
            ``keyspace_path`` exists, otherwise an empty keyspace.
        """
        if input_branch_configuration_path is not None:
            return cls.from_branch_configuration(input_branch_configuration_path)
        if keyspace_path.exists():
            keyspace = cls.from_previous_run(keyspace_path, branches_path)
            keyspace.add_draws(extras.get("num_draws", 0))
            keyspace.add_seeds(extras.get("num_seeds", 0))
            return keyspace
        return cls([], {})

    def persist(self, keyspace_path: Path, branches_path: Path) -> None:
        """Write the keyspace mapping and the branches to YAML files.

        Parameters
        ----------
        keyspace_path
            Destination file for the keyspace mapping.
        branches_path
            Destination file for the list of branches.
        """
        keyspace_path.write_text(yaml.dump(self._keyspace))
        branches_path.write_text(yaml.dump(self.branches))

    def add_draws(self, num_draws: int) -> None:
        """Append ``num_draws`` newly calculated input draws to the keyspace."""
        existing = self._keyspace["input_draw"]
        additional = calculate_input_draws(num_draws, existing)
        self._keyspace["input_draw"] = existing + additional

    def add_seeds(self, num_seeds: int) -> None:
        """Append ``num_seeds`` newly calculated random seeds to the keyspace."""
        existing = self._keyspace["random_seed"]
        additional = calculate_random_seeds(num_seeds, existing)
        self._keyspace["random_seed"] = existing + additional

    def __contains__(self, item: str) -> bool:
        """Checks whether the item is present in the Keyspace"""
        return item in self._keyspace

    def __iter__(self) -> Iterator[tuple[int, int, dict[str, Any]]]:
        """Yields an individual simulation configuration from the keyspace."""
        for input_draw, random_seed, branch in product(
            self._keyspace["input_draw"], self._keyspace["random_seed"], self.branches
        ):
            # product() yields immutable tuples, so a missing branch has to be
            # replaced while building the result rather than assigned in place.
            yield input_draw, random_seed, ({} if branch is None else branch)

    def __len__(self) -> int:
        """Returns the number of individual simulation runs this keyspace represents."""
        # Size of the cartesian product, without materialising it.
        return (
            len(self._keyspace["input_draw"])
            * len(self._keyspace["random_seed"])
            * len(self.branches)
        )
[docs]
def calculate_random_seeds(
    random_seed_count: int, existing_seeds: list[int] | None = None
) -> list[int]:
    """Generates random seeds to use given a count of seeds and any existing seeds.

    Parameters
    ----------
    random_seed_count
        The number of random seeds to generate.
    existing_seeds
        Any random seeds that have already been generated and should not be
        generated again.

    Returns
    -------
        A list of unique random seeds, guaranteed not to overlap with any
        existing random seeds. Empty if ``random_seed_count`` is 0.

    Raises
    ------
    ValueError
        If ``random_seed_count`` is negative, exceeds the maximum number of
        seeds, or exceeds the number of seeds not already in use.
    """
    if not random_seed_count:
        return []
    if random_seed_count < 0:
        # A negative count would otherwise slice from the end of the shuffled
        # pool and return almost every available seed.
        raise ValueError(
            f"Random seed count must not be negative. Given: {random_seed_count}"
        )

    max_seed_count = 10000
    if random_seed_count > max_seed_count:
        raise ValueError(f"Random seed count must be less than {max_seed_count}.")

    possible = list(range(max_seed_count))
    if existing_seeds:
        possible = sorted(set(possible).difference(existing_seeds))
    if random_seed_count > len(possible):
        # Fail loudly instead of silently returning fewer seeds than requested.
        raise ValueError(
            f"Requested {random_seed_count} random seeds but only {len(possible)} "
            f"unused seeds remain below {max_seed_count}."
        )

    # A fixed seed makes the selection reproducible between runs. Note that
    # this reseeds numpy's global random state.
    np.random.seed(654321)
    np.random.shuffle(possible)
    return possible[:random_seed_count]
[docs]
def calculate_keyspace(branches: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Collect the distinct values each configuration key takes across branches.

    Parameters
    ----------
    branches
        The (already expanded) branch configurations. Every branch must
        define exactly the same set of keys.

    Returns
    -------
        A mapping from each collapsed configuration key to the list of
        distinct values it takes, in order of first appearance. Empty if
        there are no branches.

    Raises
    ------
    ValueError
        If the branches do not all have the same keys.
    FileNotFoundError
        If any branch specifies an artifact path that fails validation.
    """
    if not branches:
        return {}

    # dicts are used as insertion-ordered sets so the value order is stable.
    seen: dict[str, dict[Any, None]] = {}
    expected_keys: set[str] | None = None
    for branch in branches:
        flat_branch = dict(collapse_nested_dict(branch))
        if expected_keys is None:
            expected_keys = set(flat_branch)
        elif set(flat_branch) != expected_keys:
            raise ValueError("All branches must have the same keys")
        for key, value in flat_branch.items():
            # Validate every branch, including the first one.
            if key == FULL_ARTIFACT_PATH_KEY:
                validate_artifact_path(value)
            seen.setdefault(key, {})[value] = None

    return {key: list(values) for key, values in seen.items()}
[docs]
def load_branch_configuration(
    path: Path,
) -> tuple[list[dict[str, Any]], int, int, list[int] | None, list[int] | None]:
    """Load and validate a branch configuration file.

    Parameters
    ----------
    path
        Path to the YAML branch configuration file.

    Returns
    -------
        A tuple of the expanded branches, the input draw count, the random
        seed count, the explicit input draws (or ``None``) and the explicit
        random seeds (or ``None``). Both counts default to 1 and the branches
        default to a single empty branch.

    Raises
    ------
    ValueError
        If the file does not contain a mapping, or if the counts and values
        for input draws or random seeds are inconsistent or out of range.
    """
    # NOTE(review): yaml.full_load is more permissive than yaml.safe_load;
    # consider safe_load if branch files can come from untrusted sources.
    data = yaml.full_load(path.read_text())
    if data is None:
        # An empty (or comment-only) YAML document parses to None.
        data = {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Branch configuration must be a mapping. Given: {type(data).__name__}"
        )

    input_draw_count = data.get("input_draw_count", 1)
    random_seed_count = data.get("random_seed_count", 1)
    input_draws = data.get("input_draws", None)
    random_seeds = data.get("random_seeds", None)

    # Validate configuration of counts and values for input_draws and random_seeds
    _check_count_and_values(
        data, input_draw_count, input_draws, "input_draw_count", "input_draws", 1000
    )
    _check_count_and_values(
        data,
        random_seed_count,
        random_seeds,
        "random_seed_count",
        "random_seeds",
        10000,
    )

    # A ``branches`` key with no value is treated like a missing key.
    branch_templates = data.get("branches")
    if branch_templates is None:
        branches: list[dict[str, Any]] = [{}]
    else:
        branches = expand_branch_templates(branch_templates)
    return branches, input_draw_count, random_seed_count, input_draws, random_seeds
def _check_count_and_values(
configuration: dict[str, Any],
value_count: int,
values: list[int],
count_name: str,
values_name: str,
max_count: int,
) -> None:
"""Checks input configuration count and values for integers outside of range.
Parameters
----------
configuration
Dictionary of the configuration data.
value_count
Integer for the number of values provided.
values
List of integer values.
count_name
Configuration key string for value count.
values_name
Configuration key string for values list.
max_count
Integer for the maximum number of values, maximum value is max_count - 1.
"""
if count_name in configuration and values_name in configuration:
if len(values) != value_count:
raise ValueError(
f"Both {count_name} and {values_name} are defined but they are inconsistent. "
f"{count_name} is {value_count} while {values_name} has length {len(values)}. "
)
if values:
if [d for d in values if d not in range(0, max_count)]:
raise ValueError(
f"{values_name} contains draws outside of 0-{max_count - 1}: "
f"{[d for d in values if d not in range(0, max_count)]}"
)
if value_count < 1 or value_count > max_count:
raise ValueError(f"{count_name} must be within 1-{max_count}. Given: {value_count}")
[docs]
def expand_branch_templates(templates: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Expand branch template lists into individual branches.

    Take a list of dictionaries of configuration values (like the ones used in
    experiment branch configurations) and expand it by taking any values which
    are lists and creating a new set of branches which is made up of the
    product of all those lists plus all non-list values.

    For example this:

    .. code::

        [{'a': {'b': [1,2], 'c': 3, 'd': [4,5,6]}}]

    becomes this:

    .. code::

        [
            {'a': {'b': 1, 'c': 3, 'd': 4}},
            {'a': {'b': 2, 'c': 3, 'd': 4}},
            {'a': {'b': 1, 'c': 3, 'd': 5}},
            {'a': {'b': 2, 'c': 3, 'd': 5}},
            {'a': {'b': 1, 'c': 3, 'd': 6}},
            {'a': {'b': 2, 'c': 3, 'd': 6}}
        ]

    Parameters
    ----------
    templates
        A list of dictionaries of configuration values that may contain lists.

    Returns
    -------
        A list of dictionaries, each representing a single branch configuration.
        Within a template, the first key in sorted order varies fastest.
    """
    expanded_branches = []
    for branch_template in templates:
        # Sorting by key makes the expansion order independent of the order
        # in which keys appear in the template.
        branch_items = sorted(collapse_nested_dict(branch_template))
        keys = [k for k, _ in branch_items]
        choices = [v if isinstance(v, list) else [v] for _, v in branch_items]
        # product() varies its last argument fastest. Reversing the inputs and
        # each resulting combination makes the first sorted key vary fastest.
        for combination in product(*reversed(choices)):
            expanded_branches.append(list(zip(keys, reversed(combination))))

    # Rebuild nested dictionaries from the dotted keys.
    final_branches = []
    for branch_items in expanded_branches:
        root: dict[str, Any] = {}
        for dotted_key, value in branch_items:
            *parents, leaf = dotted_key.split(".")
            current = root
            for parent in parents:
                current = current.setdefault(parent, {})
            current[leaf] = value
        final_branches.append(root)
    return final_branches
[docs]
def validate_artifact_path(artifact_path: str) -> None:
    """Check that an artifact path from the branches file is usable.

    A usable path is absolute and points at something that exists on disk.

    Parameters
    ----------
    artifact_path
        The path to the artifact.

    Raises
    ------
    FileNotFoundError
        If the artifact path is not an absolute path or does not exist.
    """
    candidate = Path(artifact_path)
    if candidate.is_absolute() and candidate.exists():
        return
    raise FileNotFoundError(f"Cannot find artifact at path {candidate}")