"""Output builders (iris.orchestration.output_builders): assemble pipeline output dictionaries from a call trace."""

import traceback
from collections.abc import Mapping
from typing import Any, Dict, List, Optional

import numpy as np

from iris._version import __version__
from iris.callbacks.pipeline_trace import PipelineCallTraceStorage
from iris.io.dataclasses import ImmutableModel, OutputFieldSpec


def _nested_safe_serialize(obj: Any) -> Any:
    """Recursively serialize ``obj`` for inclusion in a pipeline output dict.

    Mappings are rebuilt with every value serialized, and lists/tuples are serialized
    element-wise (always producing a list). Both cases recurse through this function, so
    containers nested in each other (e.g. a list of dicts, or a dict of lists) are handled
    at any depth. Any other non-None value is delegated to ``__safe_serialize``, which
    handles ``ImmutableModel`` and ``np.ndarray`` leaves.

    Args:
        obj (Any): Object to serialize. May be None.

    Raises:
        NotImplementedError: Propagated from ``__safe_serialize`` for unsupported leaf types.

    Returns:
        Any: Serialized object, or None if ``obj`` is None.
    """
    if obj is None:
        return None
    # Handle mappings by serializing each value
    if isinstance(obj, Mapping):
        return {k: _nested_safe_serialize(v) for k, v in obj.items()}
    # Recurse over sequences here (not in __safe_serialize, which only recurses into itself)
    # so that mappings nested inside lists/tuples are serialized instead of rejected.
    if isinstance(obj, (list, tuple)):
        return [_nested_safe_serialize(item) for item in obj]
    # Leaf values: ImmutableModel, np.ndarray (or an unsupported type, which raises)
    return __safe_serialize(obj)


def _build_from_spec(call_trace: PipelineCallTraceStorage, spec: List[OutputFieldSpec]) -> Dict[str, Any]:
    """
    Build an output dict by evaluating every field described in ``spec`` against ``call_trace``.

    For each field, ``field.extractor(call_trace)`` yields the raw value. When the field's
    ``safe_serialize`` flag is set, that value is passed through ``_nested_safe_serialize``
    before being stored under ``field.key``.

    Args:
        call_trace (PipelineCallTraceStorage): The pipeline call trace storage object.
        spec (List[OutputFieldSpec]): Field specifications, evaluated in order.

    Returns:
        Dict[str, Any]: Mapping from each field's key to its (possibly serialized) value.
    """

    def _resolve(field_spec: OutputFieldSpec) -> Any:
        # Pull the raw value out of the trace; serialize only when the spec asks for it.
        raw_value = field_spec.extractor(call_trace)
        return _nested_safe_serialize(raw_value) if field_spec.safe_serialize else raw_value

    return {field_spec.key: _resolve(field_spec) for field_spec in spec}


def __safe_serialize(object: Optional[Any]) -> Optional[Any]:
    """Convert a pipeline result into plain, serializable Python data.

    Supported inputs:
        * ``None`` -> ``None``
        * ``ImmutableModel`` -> the model's ``serialize()`` output
        * ``np.ndarray`` -> nested Python lists via ``tolist()``
        * ``list`` / ``tuple`` -> a list with each element serialized by this function

    Args:
        object (Optional[Any]): Value to serialize.

    Raises:
        NotImplementedError: If ``object`` (or any element of a list/tuple) has an unsupported type.

    Returns:
        Optional[Any]: Serialized value.
    """
    if object is None:
        return None
    if isinstance(object, ImmutableModel):
        return object.serialize()
    if isinstance(object, np.ndarray):
        return object.tolist()
    if isinstance(object, (list, tuple)):
        # Sequences always come back as lists, element by element.
        return list(map(__safe_serialize, object))
    raise NotImplementedError(f"Object of type {type(object)} is not serializable.")


def __get_iris_pipeline_metadata(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
    """Assemble the metadata section of an iris pipeline output.

    Image-level fields (size and eye side) are read from the pipeline input. Every other
    field is the ``__safe_serialize``-d result of one call-trace entry, looked up by name.

    Args:
        call_trace (PipelineCallTraceStorage): Pipeline call trace.

    Returns:
        Dict[str, Any]: Metadata dictionary.
    """
    input_image = call_trace.get_input()

    metadata: Dict[str, Any] = {
        "iris_version": __version__,
        "image_size": (input_image.width, input_image.height),
        "eye_side": input_image.eye_side,
    }

    # (output key, call-trace result name) pairs, listed in output order.
    traced_fields = (
        ("eye_centers", "eye_center_estimation"),
        ("pupil_to_iris_property", "pupil_to_iris_property_estimation"),
        ("offgaze_score", "offgaze_estimation"),
        ("eye_orientation", "eye_orientation"),
        ("occlusion90", "occlusion90_calculator"),
        ("occlusion30", "occlusion30_calculator"),
        ("iris_bbox", "bounding_box_estimation"),
        ("sharpness_score", "sharpness_estimation"),
    )
    for output_key, result_name in traced_fields:
        metadata[output_key] = __safe_serialize(call_trace.get(result_name))

    return metadata


def __get_error(call_trace: PipelineCallTraceStorage) -> Optional[Dict[str, Any]]:
    """Describe the error recorded in a call trace, if any.

    Args:
        call_trace (PipelineCallTraceStorage): Pipeline call trace.

    Returns:
        Optional[Dict[str, Any]]: ``None`` when the trace's error is not an ``Exception``;
        otherwise a dict with the exception class name, its message, and its formatted
        traceback frames joined into one string.
    """
    recorded = call_trace.get_error()

    # Anything that is not an Exception instance (e.g. None) means "no error".
    if not isinstance(recorded, Exception):
        return None

    return {
        "error_type": type(recorded).__name__,
        "message": str(recorded),
        "traceback": "".join(traceback.format_tb(recorded.__traceback__)),
    }


def __get_multiframe_aggregation_metadata(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
    """Assemble the metadata section of a multiframe aggregation output.

    Args:
        call_trace (PipelineCallTraceStorage): Pipeline call trace whose input is the
            collection of templates being aggregated.

    Returns:
        Dict[str, Any]: The library version and the number of input templates.
    """
    input_templates = call_trace.get_input()
    template_count = len(input_templates)
    return {"iris_version": __version__, "templates_count": template_count}


# =============================================================================
# Specs for different output variants
# =============================================================================

# Simple ORB output: error info, iris_template and metadata, none of them safe-serialized
IRIS_PIPE_SIMPLE_ORB_OUTPUT_SPEC = [
    OutputFieldSpec(key=field_key, extractor=field_extractor, safe_serialize=False)
    for field_key, field_extractor in (
        ("error", __get_error),
        ("iris_template", lambda ct: ct.get("encoder")),
        ("metadata", __get_iris_pipeline_metadata),
    )
]

# ORB output: same fields as the simple variant, but iris_template is safe-serialized
IRIS_PIPE_ORB_OUTPUT_SPEC = [
    OutputFieldSpec(key=field_key, extractor=field_extractor, safe_serialize=serialize_flag)
    for field_key, field_extractor, serialize_flag in (
        ("error", __get_error, False),
        ("iris_template", lambda ct: ct.get("encoder"), True),
        ("metadata", __get_iris_pipeline_metadata, False),
    )
]

def _segmentation_binarization_output(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
    """Split the ``segmentation_binarization`` result into its geometry and noise parts.

    Args:
        call_trace (PipelineCallTraceStorage): Pipeline call trace.

    Returns:
        Dict[str, Any]: ``{"geometry": result[0], "noise": result[1]}``, or both values None
        when the ``segmentation_binarization`` result is None.
    """
    # Look the result up once instead of once per condition and once per key.
    binarization = call_trace.get("segmentation_binarization")
    if binarization is None:
        return {"geometry": None, "noise": None}
    return {"geometry": binarization[0], "noise": binarization[1]}


# Debugging output: includes various intermediate pipeline results
IRIS_PIPE_DEBUG_OUTPUT_SPEC = [
    OutputFieldSpec(key="iris_template", extractor=lambda ct: ct.get("encoder"), safe_serialize=False),
    OutputFieldSpec(key="metadata", extractor=__get_iris_pipeline_metadata, safe_serialize=False),
    OutputFieldSpec(key="segmentation_map", extractor=lambda ct: ct.get("segmentation"), safe_serialize=True),
    OutputFieldSpec(
        key="segmentation_binarization",
        extractor=_segmentation_binarization_output,
        safe_serialize=True,
    ),
    OutputFieldSpec(
        key="extrapolated_polygons", extractor=lambda ct: ct.get("geometry_estimation"), safe_serialize=True
    ),
    OutputFieldSpec(key="normalized_iris", extractor=lambda ct: ct.get("normalization"), safe_serialize=True),
    OutputFieldSpec(key="iris_response", extractor=lambda ct: ct.get("filter_bank"), safe_serialize=True),
    OutputFieldSpec(
        key="iris_response_refined", extractor=lambda ct: ct.get("iris_response_refinement"), safe_serialize=True
    ),
    OutputFieldSpec(key="error", extractor=__get_error, safe_serialize=False),
]

def _templates_aggregation_item(call_trace: PipelineCallTraceStorage, index: int) -> Any:
    """Return element ``index`` of the ``templates_aggregation`` result.

    Args:
        call_trace (PipelineCallTraceStorage): Pipeline call trace.
        index (int): Position within the aggregation result (0 = template, 1 = weights,
            per the output keys that use it below).

    Returns:
        Any: ``result[index]``, or None when the ``templates_aggregation`` result is None.
    """
    # Single lookup, name-only `get` call: the previous dict-style default ([None, None]) was
    # unreachable behind the None check. NOTE(review): it would also raise TypeError if
    # PipelineCallTraceStorage.get accepts only the result name.
    aggregation = call_trace.get("templates_aggregation")
    return None if aggregation is None else aggregation[index]


MULTIFRAME_AGG_ORB_OUTPUT_SPEC = [
    OutputFieldSpec(key="error", extractor=__get_error, safe_serialize=False),
    OutputFieldSpec(
        key="iris_template",
        extractor=lambda ct: _templates_aggregation_item(ct, 0),
        safe_serialize=True,
    ),
    OutputFieldSpec(
        key="weights",
        extractor=lambda ct: _templates_aggregation_item(ct, 1),
        safe_serialize=True,
    ),
    OutputFieldSpec(key="metadata", extractor=__get_multiframe_aggregation_metadata, safe_serialize=False),
]

MULTIFRAME_AGG_SIMPLE_ORB_OUTPUT_SPEC = [
    OutputFieldSpec(key="error", extractor=__get_error, safe_serialize=False),
    OutputFieldSpec(
        key="iris_template",
        extractor=lambda ct: _templates_aggregation_item(ct, 0),
        safe_serialize=False,
    ),
    OutputFieldSpec(
        key="weights",
        extractor=lambda ct: _templates_aggregation_item(ct, 1),
        safe_serialize=False,
    ),
    OutputFieldSpec(key="metadata", extractor=__get_multiframe_aggregation_metadata, safe_serialize=False),
]

# =============================================================================
# Builder functions leveraging the generic engine
# =============================================================================


def build_simple_iris_pipeline_orb_output(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
    """Build the simple ORB output for an iris pipeline run.

    The result holds ``error``, ``iris_template`` and ``metadata`` exactly as their
    extractors produce them; no field is passed through safe serialization.

    Args:
        call_trace (PipelineCallTraceStorage): Pipeline call trace.

    Returns:
        Dict[str, Any]: Output dictionary keyed by field name.
    """
    return _build_from_spec(call_trace=call_trace, spec=IRIS_PIPE_SIMPLE_ORB_OUTPUT_SPEC)


def build_iris_pipeline_orb_output(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
    """Build the ORB output for an iris pipeline run.

    Contains the same fields as the simple ORB output (``error``, ``iris_template``,
    ``metadata``), but ``iris_template`` is passed through safe serialization.

    Args:
        call_trace (PipelineCallTraceStorage): Pipeline call trace.

    Returns:
        Dict[str, Any]: Output dictionary keyed by field name.
    """
    return _build_from_spec(call_trace=call_trace, spec=IRIS_PIPE_ORB_OUTPUT_SPEC)


def build_simple_iris_pipeline_debugging_output(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
    """Build debugging output containing intermediate pipeline results.

    Fields follow ``IRIS_PIPE_DEBUG_OUTPUT_SPEC``:
        * ``iris_template``, ``metadata`` and ``error`` are returned as their extractors
          produce them. In particular, ``iris_template`` is NOT serialized.
        * The intermediate results are passed through safe serialization. These are
          ``segmentation_map``, ``segmentation_binarization``, ``extrapolated_polygons``,
          ``normalized_iris``, ``iris_response`` and ``iris_response_refined``.

    Args:
        call_trace (PipelineCallTraceStorage): Pipeline call trace.

    Returns:
        Dict[str, Any]: Output dictionary keyed by field name.
    """
    return _build_from_spec(call_trace, IRIS_PIPE_DEBUG_OUTPUT_SPEC)


def build_iris_pipeline_debugging_output(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
    """Build the full debugging output for an iris pipeline run.

    Starts from ``build_simple_iris_pipeline_debugging_output`` and then replaces its raw
    ``iris_template`` entry with the ``__safe_serialize``-d form.

    Args:
        call_trace (PipelineCallTraceStorage): Pipeline call trace.

    Raises:
        NotImplementedError: If the iris template has a type ``__safe_serialize`` cannot handle.

    Returns:
        Dict[str, Any]: Output dictionary keyed by field name.
    """
    debug_output = build_simple_iris_pipeline_debugging_output(call_trace)
    raw_template = debug_output.get("iris_template")
    debug_output["iris_template"] = __safe_serialize(raw_template)
    return debug_output


def build_aggregation_multiframe_orb_output(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
    """Build the ORB output for a multiframe template aggregation run.

    Contains ``error``, ``iris_template``, ``weights`` and ``metadata``. The aggregated
    ``iris_template`` and ``weights`` are passed through safe serialization.

    Args:
        call_trace (PipelineCallTraceStorage): Pipeline call trace.

    Returns:
        Dict[str, Any]: Output dictionary keyed by field name.
    """
    return _build_from_spec(call_trace=call_trace, spec=MULTIFRAME_AGG_ORB_OUTPUT_SPEC)


def build_simple_multiframe_aggregation_output(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
    """Build the simple output for a multiframe template aggregation run.

    Contains ``error``, ``iris_template``, ``weights`` and ``metadata`` exactly as their
    extractors produce them; no field is passed through safe serialization.

    Args:
        call_trace (PipelineCallTraceStorage): Pipeline call trace.

    Returns:
        Dict[str, Any]: Output dictionary keyed by field name.
    """
    return _build_from_spec(call_trace=call_trace, spec=MULTIFRAME_AGG_SIMPLE_ORB_OUTPUT_SPEC)