# Source code for vivarium.cluster_tools.dagger.config.parsing

"""
=====================
Workflow Step Parsing
=====================

YAML -> API kwargs translation for workflow steps. Each step type has a
parser that turns the raw YAML dict into the kwargs of its matching
interface API function. Parsers do YAML-shape validation (required
fields, unsupported ``args`` keys, ``command``/``type`` conflicts) inline.

Also exposes workflow-level entry points: :func:`parse_step_from_yaml`
and :func:`load_workflow_config`. ParsedStep -> YAML dict serialization
lives in
:mod:`vivarium.cluster_tools.dagger.config.serialization`.

"""

from __future__ import annotations

from collections.abc import Mapping
from pathlib import Path
from typing import Any, Callable, cast

from vivarium.cluster_tools.dagger.config.config import (
    DEFAULT_MAX_ATTEMPTS,
    ParsedStep,
    ResourceConfig,
    WorkflowConfig,
)

_SIMULATION_SUPPORTED_ARGS: set[str] = {
    "model_specification",
    "branch_configuration",
    "artifact_path",
    "backup_freq",
    "sim_verbosity",
}
_PYTEST_SUPPORTED_ARGS: set[str] = {"path", "k", "runslow"}
_PYTHON_SUPPORTED_ARGS: set[str] = {"path", "positional_args", "keyword_args"}
_NOTEBOOK_SUPPORTED_ARGS: set[str] = {"path", "parameters", "output_path", "cwd"}


def _parse_common_step_fields(
    data: dict[str, Any],
    *,
    project: str,
    queue: str,
) -> tuple[str, ResourceConfig]:
    """Validate and extract the fields required by every step type.

    Checks that the step has a ``name``, a ``resources`` block, and that the
    block contains ``memory_gb``. Returns the step name and the constructed
    :class:`ResourceConfig`.
    """
    if "name" not in data:
        raise ValueError("Step: missing required field 'name'.")
    step_name = data["name"]
    if "resources" not in data:
        raise ValueError(f"Step '{step_name}': missing required field 'resources'.")
    resources_data = data["resources"]
    if "memory_gb" not in resources_data:
        raise ValueError(f"Step '{step_name}': missing required 'memory_gb' in 'resources'.")
    resources = ResourceConfig.from_dict(
        resources_data, workflow_project=project, workflow_queue=queue
    )
    return step_name, resources


def _get_required_args(
    data: dict[str, Any], step_name: str, required_fields: tuple[str, ...]
) -> dict[str, Any]:
    """Return ``data['args']`` after asserting the block and all ``required_fields`` are present."""
    if "args" not in data:
        raise ValueError(f"Step '{step_name}': missing required field 'args'.")
    args = cast(dict[str, Any], data["args"])
    for field in required_fields:
        if field not in args:
            raise ValueError(f"Step '{step_name}': missing required '{field}' in 'args'.")
    return args


def resolve_step_type(step_dict: dict[str, Any]) -> str:
    """Return the step-type key that ``step_dict`` should be parsed with.

    Resolution order:

    - A top-level ``command`` field short-circuits to ``"bash"``. Any
      conflicting ``type`` is reported later by
      :func:`parse_bash_step_from_yaml`, which owns the bash-step schema.
    - Otherwise the step's explicit ``type`` is returned, provided it is a
      key of ``STEP_TYPE_YAML_PARSERS``.
    - A step carrying neither ``command`` nor ``type`` is rejected.

    Raises
    ------
    ValueError
        If ``name`` is missing (only checked when there is no ``command``),
        if ``type`` is missing, or if ``type`` is not a supported step type.
    """
    if "command" in step_dict:
        return "bash"
    if "name" not in step_dict:
        raise ValueError("Step: missing required field 'name'.")
    label = step_dict["name"]

    if "type" in step_dict:
        candidate: str = step_dict["type"]
        if candidate in STEP_TYPE_YAML_PARSERS:
            return candidate
        raise ValueError(
            f"Step '{label}': unsupported type '{candidate}'. "
            f"Must be one of: {sorted(STEP_TYPE_YAML_PARSERS)}."
        )

    raise ValueError(
        f"Step '{label}': must specify either a 'command' field or a "
        f"'type' field. Supported types: {sorted(STEP_TYPE_YAML_PARSERS)}."
    )
def parse_bash_step_from_yaml(
    data: dict[str, Any],
    output_directory: Path,
    *,
    project: str,
    queue: str,
) -> dict[str, Any]:
    """Translate a raw bash-step YAML dict into API kwargs.

    A bash step is identified by a top-level ``command`` field. If a
    ``type`` field is also given, it has to be ``"bash"``. An ``args:``
    block is never allowed.

    Parameters
    ----------
    data
        Raw step dict from the workflow YAML.
    output_directory
        Workflow output directory, passed through unchanged.
    project
        Workflow-level project used when building the step's resources.
    queue
        Workflow-level queue used when building the step's resources.

    Returns
    -------
        Kwargs with keys ``name``, ``resources``, ``command``,
        ``output_directory`` and ``environment`` (``None`` when unset).

    Raises
    ------
    ValueError
        On invalid common fields, a present ``args`` block, a missing
        ``command``, or a ``type`` other than ``"bash"``.

    Examples
    --------
    YAML configuration::

        steps:
          - name: post_analysis
            command: python scripts/analyze.py --input /results
            environment: analysis_env
            resources:
              memory_gb: 20
              runtime: "02:00:00"
              cores: 2
    """
    step_name, resources = _parse_common_step_fields(data, project=project, queue=queue)

    if "args" in data:
        raise ValueError(f"Step '{step_name}': bash steps do not support an 'args' block.")
    if "command" not in data:
        raise ValueError(f"Step '{step_name}': missing required field 'command'.")

    # An omitted ``type`` is treated as bash; anything else conflicts with ``command``.
    declared_type = data.get("type", "bash")
    if declared_type != "bash":
        raise ValueError(
            f"Step '{step_name}': cannot specify both 'command' and "
            f"'type: {declared_type}'. When 'command' is set, 'type' "
            "must be omitted or set to 'bash'."
        )

    return dict(
        name=step_name,
        resources=resources,
        command=data["command"],
        output_directory=output_directory,
        environment=data.get("environment"),
    )
def parse_simulation_step_from_yaml(
    data: dict[str, Any],
    output_directory: Path,
    *,
    project: str,
    queue: str,
) -> dict[str, Any]:
    """Translate a raw simulation-step YAML dict into API kwargs.

    ``model_specification`` and ``branch_configuration`` are required in
    ``args`` and are resolved to absolute :class:`~pathlib.Path` objects.
    Of the optional keys, ``artifact_path`` is resolved the same way, while
    ``backup_freq`` and ``sim_verbosity`` are forwarded verbatim. Optional
    keys that are absent are left out of the result.

    Raises
    ------
    ValueError
        On invalid common fields, a missing/invalid ``args`` block or
        required key, or any unsupported ``args`` key.

    Examples
    --------
    YAML configuration::

        steps:
          - name: model_sims
            type: simulation
            resources:
              memory_gb: 3
              runtime: "24:00:00"
            args:
              model_specification: /path/to/model.yaml
              branch_configuration: /path/to/branches.yaml
              artifact_path: /path/to/artifact.hdf
              backup_freq: 1800
              sim_verbosity: 1
    """
    step_name, resources = _parse_common_step_fields(data, project=project, queue=queue)
    args = _get_required_args(
        data, step_name, ("model_specification", "branch_configuration")
    )
    _check_supported_args(args, step_name, _SIMULATION_SUPPORTED_ARGS)

    kwargs: dict[str, Any] = {
        "name": step_name,
        "resources": resources,
        "output_directory": output_directory,
        "environment": data.get("environment"),
    }
    for required_key in ("model_specification", "branch_configuration"):
        kwargs[required_key] = Path(args[required_key]).resolve()

    if "artifact_path" in args:
        kwargs["artifact_path"] = Path(args["artifact_path"]).resolve()
    # These two are scalars, so they are passed through without conversion.
    for passthrough_key in ("backup_freq", "sim_verbosity"):
        if passthrough_key in args:
            kwargs[passthrough_key] = args[passthrough_key]
    return kwargs
def parse_pytest_step_from_yaml(
    data: dict[str, Any],
    output_directory: Path,
    *,
    project: str,
    queue: str,
) -> dict[str, Any]:
    """Parse a raw pytest-step YAML dict into API kwargs.

    Optional ``args`` keys: ``path``, ``k``, ``runslow``. At least one of
    ``path`` or ``k`` must be provided. This is checked here, so a
    misconfigured step is reported with its YAML step name. ``path`` may be
    a single string or a list of strings, and each entry is resolved to an
    absolute path string. The ``args`` block itself may be omitted or empty
    as far as key parsing is concerned. ``environment`` is only included in
    the result when the step sets it.

    Raises
    ------
    ValueError
        On invalid common fields, any unsupported ``args`` key, or when
        neither ``path`` nor ``k`` is given.

    Examples
    --------
    YAML configuration::

        steps:
          - name: unit_tests
            type: pytest
            resources:
              memory_gb: 8
              runtime: "01:00:00"
              cores: 4
            args:
              path: tests/
              k: "test_foo"
              runslow: true

    Multiple paths::

        steps:
          - name: unit_and_integration
            type: pytest
            resources:
              memory_gb: 8
              runtime: "01:00:00"
            args:
              path:
                - tests/unit
                - tests/integration
    """
    step_name, resources = _parse_common_step_fields(data, project=project, queue=queue)
    # A null ``args:`` block is treated the same as an omitted one.
    args = data.get("args", {}) or {}
    _check_supported_args(args, step_name, _PYTEST_SUPPORTED_ARGS)
    # Enforce the documented contract inline, like the other YAML-shape checks.
    if "path" not in args and "k" not in args:
        raise ValueError(
            f"Step '{step_name}': pytest steps require at least one of "
            "'path' or 'k' in 'args'."
        )

    kwargs: dict[str, Any] = {
        "name": step_name,
        "resources": resources,
        "output_directory": output_directory,
    }
    if "environment" in data:
        kwargs["environment"] = data["environment"]
    if "path" in args:
        raw_path = args["path"]
        if isinstance(raw_path, list):
            kwargs["path"] = [str(Path(p).resolve()) for p in raw_path]
        else:
            kwargs["path"] = str(Path(raw_path).resolve())
    if "k" in args:
        kwargs["k"] = args["k"]
    if "runslow" in args:
        kwargs["runslow"] = args["runslow"]
    return kwargs
def parse_python_step_from_yaml(
    data: dict[str, Any],
    output_directory: Path,
    *,
    project: str,
    queue: str,
) -> dict[str, Any]:
    """Translate a raw python-step YAML dict into API kwargs.

    ``args.path`` (a ``.py`` script) is required and is returned as an
    absolute path string. ``positional_args`` (list of scalars) and
    ``keyword_args`` (dict of identifier-keyed scalars) are optional.
    They are forwarded verbatim when present and omitted otherwise.
    ``environment`` is always included (``None`` when unset).

    Raises
    ------
    ValueError
        On invalid common fields, a missing/invalid ``args`` block or
        ``path``, or any unsupported ``args`` key.

    Examples
    --------
    YAML configuration::

        steps:
          - name: postprocess
            type: python
            resources:
              memory_gb: 8
              runtime: "00:30:00"
            args:
              path: scripts/postprocess.py
              positional_args:
                - "foo"
                - "bar"
              keyword_args:
                input_dir: /mnt/results/model_29
                verbose: true
                num_workers: 4
    """
    step_name, resources = _parse_common_step_fields(data, project=project, queue=queue)
    args = _get_required_args(data, step_name, ("path",))
    _check_supported_args(args, step_name, _PYTHON_SUPPORTED_ARGS)

    script_path = str(Path(args["path"]).resolve())
    kwargs: dict[str, Any] = {
        "name": step_name,
        "resources": resources,
        "output_directory": output_directory,
        "environment": data.get("environment"),
        "path": script_path,
    }
    kwargs.update(
        {key: args[key] for key in ("positional_args", "keyword_args") if key in args}
    )
    return kwargs
def parse_notebook_step_from_yaml(
    data: dict[str, Any],
    output_directory: Path,
    *,
    project: str,
    queue: str,
) -> dict[str, Any]:
    """Translate a raw notebook-step YAML dict into API kwargs.

    ``args.path`` (input ``.ipynb``) and ``args.output_path`` (executed
    ``.ipynb``) are required and are resolved to absolute
    :class:`~pathlib.Path` objects. ``args.parameters`` (dict of
    identifier-keyed scalars injected into the notebook) is optional and is
    forwarded verbatim. ``args.cwd`` (working directory for execution) is
    also optional and is resolved when given. When ``cwd`` is omitted it is
    left out of the result, and the downstream default applies (documented
    as the parent of ``path``). ``environment`` is only included when the
    step sets it.

    Raises
    ------
    ValueError
        On invalid common fields, a missing/invalid ``args`` block or
        required key, or any unsupported ``args`` key.

    Examples
    --------
    YAML configuration::

        steps:
          - name: post_notebook_neonatal
            type: notebook
            resources:
              memory_gb: 20
              runtime: "02:00:00"
            args:
              path: tests/model_notebooks/results/neonatal.ipynb
              output_path: /mnt/results/run_29/executed/neonatal.ipynb
              parameters:
                model_dir: /mnt/results/run_29
                year: 2020
                verbose: true
    """
    step_name, resources = _parse_common_step_fields(data, project=project, queue=queue)
    args = _get_required_args(data, step_name, ("path", "output_path"))
    _check_supported_args(args, step_name, _NOTEBOOK_SUPPORTED_ARGS)

    notebook_in = Path(args["path"]).resolve()
    notebook_out = Path(args["output_path"]).resolve()
    kwargs: dict[str, Any] = {
        "name": step_name,
        "resources": resources,
        "output_directory": output_directory,
        "path": notebook_in,
        "output_path": notebook_out,
    }

    if "environment" in data:
        kwargs["environment"] = data["environment"]
    if "parameters" in args:
        kwargs["parameters"] = args["parameters"]
    if "cwd" in args:
        kwargs["cwd"] = Path(args["cwd"]).resolve()
    return kwargs
def _check_supported_args(args: dict[str, Any], step_name: str, supported: set[str]) -> None:
    """Reject any ``args`` key that is not listed in ``supported``.

    Raises
    ------
    ValueError
        Listing every unsupported key (sorted) next to the sorted supported
        keys, prefixed with the step name.
    """
    extra_keys = sorted(set(args).difference(supported))
    if not extra_keys:
        return
    raise ValueError(
        f"Step '{step_name}': unsupported args {extra_keys}. "
        f"Supported args: {sorted(supported)}."
    )


# Dispatch table consulted by ``resolve_step_type`` (membership / supported
# type listing) and ``parse_step_from_yaml`` (parser lookup).
STEP_TYPE_YAML_PARSERS: dict[str, Callable[..., dict[str, Any]]] = {
    "bash": parse_bash_step_from_yaml,
    "simulation": parse_simulation_step_from_yaml,
    "pytest": parse_pytest_step_from_yaml,
    "python": parse_python_step_from_yaml,
    "notebook": parse_notebook_step_from_yaml,
}
"""Maps each YAML ``step_type`` to its YAML -> API kwargs parser."""
def parse_step_from_yaml(
    raw: dict[str, Any],
    output_directory: Path,
    *,
    project: str,
    queue: str,
) -> ParsedStep:
    """Turn one raw YAML step dict into a :class:`ParsedStep`.

    The step type is chosen by :func:`resolve_step_type`, and the matching
    entry in ``STEP_TYPE_YAML_PARSERS`` builds ``api_kwargs``. The result
    carries the resolved ``step_type`` so downstream consumers (task
    building, YAML serialization) can dispatch on it. Any ValueError raised
    during resolution or parsing propagates unchanged.
    """
    step_type = resolve_step_type(raw)
    parser = STEP_TYPE_YAML_PARSERS[step_type]
    api_kwargs = parser(raw, output_directory, project=project, queue=queue)
    step_name = api_kwargs["name"]
    return ParsedStep(step_type=step_type, name=step_name, api_kwargs=api_kwargs)
def load_workflow_config(
    path: Path,
    *,
    name: str | None = None,
    project: str | None = None,
    queue: str | None = None,
    output_directory: Path | None = None,
    default_environment: str | None = None,
    max_attempts: int | None = None,
) -> WorkflowConfig:
    """Load a WorkflowConfig from YAML, merging CLI overrides.

    CLI arguments take precedence over values in the YAML file. Validates
    that ``name``, ``project``, ``queue``, and ``output_directory`` are
    provided by at least one source. The output directory is resolved to an
    absolute path whichever source supplies it.

    Parameters
    ----------
    path
        Path to the workflow YAML configuration file.
    name
        CLI override for the workflow name.
    project
        CLI override for the project field.
    queue
        CLI override for the queue field.
    output_directory
        CLI override for the output directory.
    default_environment
        CLI override for the default_environment field.
    max_attempts
        CLI override for the maximum number of Jobmon task attempts. Any
        value other than ``None`` wins over the YAML file. If neither source
        provides a value (or the YAML value is null), ``DEFAULT_MAX_ATTEMPTS``
        is used.

    Raises
    ------
    ValueError
        If ``name``, ``project``, ``queue``, or ``output_directory`` cannot
        be resolved from either the YAML file or CLI arguments, if the YAML
        file has no ``steps`` (or ``steps`` is null), or if any step fails
        to parse.
    """
    workflow = WorkflowConfig.parse_yaml_file(path)

    resolved_name = name or workflow.get("name")
    resolved_project = project or workflow.get("project")
    resolved_queue = queue or workflow.get("queue")
    # Resolve both the CLI and the YAML value, so callers always get an
    # absolute path (step parsers resolve every path they receive too).
    # Before this change only the YAML value was resolved.
    raw_output_directory = output_directory or workflow.get("output_directory")
    resolved_output_directory = (
        Path(raw_output_directory).resolve() if raw_output_directory is not None else None
    )

    if not resolved_name:
        raise ValueError(
            "Workflow name is required. Provide it in the config file or via --name/-n."
        )
    if not resolved_project:
        raise ValueError(
            "Project is required. Provide it in the config file or via --project/-P."
        )
    if not resolved_queue:
        raise ValueError(
            "Queue is required. Provide it in the config file or via --queue/-q."
        )
    if resolved_output_directory is None:
        raise ValueError(
            "Output directory is required. Provide it in the config file "
            "or via --output-directory/-o."
        )

    raw_steps = workflow.get("steps")
    if raw_steps is None:
        raise ValueError(f"Workflow config '{path}' must define a 'steps' list.")
    steps = [
        parse_step_from_yaml(
            raw,
            output_directory=resolved_output_directory,
            project=resolved_project,
            queue=resolved_queue,
        )
        for raw in raw_steps
    ]

    # ``or`` would silently discard an explicit falsy CLI value. Use an
    # explicit None check, and treat a null YAML value as unset.
    resolved_max_attempts = max_attempts
    if resolved_max_attempts is None:
        resolved_max_attempts = workflow.get("max_attempts")
    if resolved_max_attempts is None:
        resolved_max_attempts = DEFAULT_MAX_ATTEMPTS

    return WorkflowConfig(
        name=resolved_name,
        project=resolved_project,
        queue=resolved_queue,
        output_directory=resolved_output_directory,
        default_environment=default_environment or workflow.get("default_environment"),
        steps=steps,
        max_attempts=resolved_max_attempts,
    )