ZenML Integration

Config generation, model promotion lifecycle, and reusable pipeline steps.

Config Generation

fair.zenml.config

Promotion

fair.zenml.promotion

Steps

fair.zenml.steps

Reusable ZenML steps for model developers.

load_model resolves a model from the ZenML artifact store using either a direct artifact version ID or a URI fallback. The materializer registered at training time handles deserialization — PyTorch, Keras, TensorFlow, or any custom materializer works transparently.

load_model(model_uri, zenml_artifact_version_id='')

Resolve model from ZenML artifact store. Framework-agnostic via materializer.

Source code in fair/zenml/steps.py
@step
def load_model(
    model_uri: str,
    zenml_artifact_version_id: str = "",
) -> Any:
    """Resolve model from ZenML artifact store. Framework-agnostic via materializer."""
    client = Client()
    if zenml_artifact_version_id:
        art = client.get_artifact_version(zenml_artifact_version_id)
    else:
        results = client.list_artifact_versions(uri=model_uri)
        if not results:
            msg = f"No artifact found for URI: {model_uri}"
            raise RuntimeError(msg)
        art = results[0]
    return art.load()
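The resolution order can be sketched with plain Python, using a stub in place of `zenml.client.Client` (the `Fake*` names below are hypothetical, for illustration only): a direct artifact-version ID wins; otherwise the first artifact whose URI matches is used, and no match raises `RuntimeError`.

```python
class FakeArtifact:
    """Stand-in for a ZenML artifact version (hypothetical)."""

    def __init__(self, uri, payload):
        self.uri = uri
        self.payload = payload

    def load(self):
        # The real materializer deserializes here; we just return the payload.
        return self.payload


class FakeClient:
    """Stand-in for zenml.client.Client (hypothetical)."""

    def __init__(self, by_id, artifacts):
        self._by_id = by_id
        self._artifacts = artifacts

    def get_artifact_version(self, version_id):
        return self._by_id[version_id]

    def list_artifact_versions(self, uri):
        return [a for a in self._artifacts if a.uri == uri]


def resolve_model(client, model_uri, zenml_artifact_version_id=""):
    # Mirrors load_model's branching: explicit version ID first, URI fallback second.
    if zenml_artifact_version_id:
        return client.get_artifact_version(zenml_artifact_version_id).load()
    results = client.list_artifact_versions(uri=model_uri)
    if not results:
        raise RuntimeError(f"No artifact found for URI: {model_uri}")
    return results[0].load()


art = FakeArtifact("s3://models/resnet", "resnet-weights")
client = FakeClient({"v1": art}, [art])

assert resolve_model(client, "", "v1") == "resnet-weights"          # ID path
assert resolve_model(client, "s3://models/resnet") == "resnet-weights"  # URI path
```

Note that when several artifact versions share a URI, the first entry of the listing is taken; callers who need a specific version should pass `zenml_artifact_version_id` explicitly.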

Instrumentation

fair.zenml.instrumentation

mlflow_training_context(hyperparameters, model_name=None, base_model_id=None, dataset_id=None)

Context manager that instruments a training step with MLflow.

Handles autolog, param logging, tag setting, and wall-clock timing. Contributors use this instead of manual MLflow calls.

Source code in fair/zenml/instrumentation.py
@contextmanager
def mlflow_training_context(
    hyperparameters: dict[str, Any],
    model_name: str | None = None,
    base_model_id: str | None = None,
    dataset_id: str | None = None,
):
    """Context manager that instruments a training step with MLflow.

    Handles autolog, param logging, tag setting, and wall-clock timing.
    Contributors use this instead of manual MLflow calls.
    """
    import mlflow

    mlflow.autolog()  # ty: ignore[possibly-missing-attribute]
    mlflow.log_params(  # ty: ignore[possibly-missing-attribute]
        {k: v for k, v in hyperparameters.items() if not isinstance(v, (dict, list))}
    )

    tags: dict[str, str] = {}
    if model_name:
        tags["fair.model_name"] = model_name
    if base_model_id:
        tags["fair.base_model"] = base_model_id
    if dataset_id:
        tags["fair.dataset"] = dataset_id
    if tags:
        mlflow.set_tags(tags)  # ty: ignore[possibly-missing-attribute]

    wall_start = time.perf_counter()
    yield
    wall_seconds = time.perf_counter() - wall_start
    log_training_wall_time(wall_seconds)
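The filtering and timing behaviour the context manager wraps around a training block can be sketched with the standard library alone, using a recording stub in place of the real `mlflow` module (`StubMlflow` and `training_context` are illustrative names, not part of the package):

```python
import time
from contextlib import contextmanager


class StubMlflow:
    """Recording stand-in for the mlflow module (hypothetical)."""

    def __init__(self):
        self.params = {}
        self.tags = {}

    def log_params(self, params):
        self.params.update(params)

    def set_tags(self, tags):
        self.tags.update(tags)


mlflow = StubMlflow()


@contextmanager
def training_context(hyperparameters, model_name=None):
    # Nested structures are dropped: only flat params reach MLflow.
    mlflow.log_params(
        {k: v for k, v in hyperparameters.items() if not isinstance(v, (dict, list))}
    )
    if model_name:
        mlflow.set_tags({"fair.model_name": model_name})
    start = time.perf_counter()
    yield
    # Stand-in for log_training_wall_time(wall_seconds).
    mlflow.params["wall_seconds"] = time.perf_counter() - start


hp = {"lr": 3e-4, "epochs": 5, "schedule": {"type": "cosine"}}
with training_context(hp, model_name="demo"):
    pass  # training happens here

assert "schedule" not in mlflow.params      # nested dict filtered out
assert mlflow.params["lr"] == 3e-4
assert mlflow.tags["fair.model_name"] == "demo"
```

One consequence of timing after `yield` without a `try`/`finally` is that wall-clock time is only logged when the training block completes without raising.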

log_evaluation_results(metrics)

Log evaluation metrics to both MLflow and ZenML fair-prefixed metadata.

Source code in fair/zenml/instrumentation.py
def log_evaluation_results(metrics: dict[str, Any]) -> None:
    """Log evaluation metrics to both MLflow and ZenML fair-prefixed metadata."""
    import mlflow

    scalar_metrics = {k: v for k, v in metrics.items() if isinstance(v, (int, float))}
    mlflow.log_metrics(scalar_metrics)  # ty: ignore[possibly-missing-attribute]
    log_fair_metrics(metrics)
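The scalar filter applied before the MLflow call can be shown in isolation (the `split_scalar_metrics` helper below is illustrative, not part of the module): only `int`/`float` values are forwarded as MLflow metrics, while the full dict, nested structures included, still reaches the fair-prefixed metadata.

```python
def split_scalar_metrics(metrics):
    # Same comprehension as in log_evaluation_results: keep numeric scalars only.
    return {k: v for k, v in metrics.items() if isinstance(v, (int, float))}


metrics = {
    "accuracy": 0.91,
    "epochs": 5,
    "confusion_matrix": [[50, 2], [3, 45]],  # non-scalar: metadata only
}

assert split_scalar_metrics(metrics) == {"accuracy": 0.91, "epochs": 5}
```

One Python subtlety: `bool` is a subclass of `int`, so `isinstance(True, (int, float))` is `True` and boolean values pass the filter and are logged as 0/1 metrics.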

Metrics

fair.zenml.metrics