Source code for iris.pipelines.iris_pipeline

from __future__ import annotations

import os
import pydoc
import re
import traceback
from typing import Any, Dict, List, Literal, Optional, Union

import numpy as np
from pydantic import validator

import iris  # noqa: F401
import iris.nodes.validators.cross_object_validators
import iris.nodes.validators.object_validators
from iris._version import __version__
from iris.callbacks.pipeline_trace import PipelineCallTraceStorage, PipelineCallTraceStorageError
from iris.io.class_configs import Algorithm
from iris.io.dataclasses import IRImage
from iris.orchestration.environment import Environment
from iris.orchestration.error_managers import store_error_manager
from iris.orchestration.output_builders import (
    build_iris_pipeline_orb_output,
    build_simple_iris_pipeline_debugging_output,
    build_simple_iris_pipeline_orb_output,
)
from iris.orchestration.pipeline_dataclasses import PipelineMetadata, PipelineNode
from iris.orchestration.validators import pipeline_config_duplicate_node_name_check, pipeline_metadata_version_check
from iris.pipelines.base_pipeline import BasePipeline, load_yaml_config
from iris.utils.base64_encoding import base64_decode_str


[docs] class IRISPipeline(BasePipeline): """ Implementation of a fully configurable iris recognition pipeline. Inherits shared logic from BasePipeline and implements input/output specifics. """ DEFAULT_PIPELINE_CFG_PATH = os.path.join(os.path.dirname(__file__), "confs", "pipeline.yaml") PACKAGE_VERSION = __version__ DEBUGGING_ENVIRONMENT = Environment( pipeline_output_builder=build_simple_iris_pipeline_debugging_output, error_manager=store_error_manager, disabled_qa=[ iris.nodes.validators.object_validators.Pupil2IrisPropertyValidator, iris.nodes.validators.object_validators.OffgazeValidator, iris.nodes.validators.object_validators.OcclusionValidator, iris.nodes.validators.object_validators.IsPupilInsideIrisValidator, iris.nodes.validators.object_validators.SharpnessValidator, iris.nodes.validators.object_validators.IsMaskTooSmallValidator, iris.nodes.validators.cross_object_validators.EyeCentersInsideImageValidator, iris.nodes.validators.cross_object_validators.ExtrapolatedPolygonsInsideImageValidator, ], call_trace_initialiser=PipelineCallTraceStorage.initialise, ) ORB_ENVIRONMENT = Environment( pipeline_output_builder=build_iris_pipeline_orb_output, error_manager=store_error_manager, call_trace_initialiser=PipelineCallTraceStorage.initialise, )
[docs] class Parameters(Algorithm.Parameters): """IRISPipeline parameters, all derived from the input `config`.""" metadata: PipelineMetadata pipeline: List[PipelineNode] _config_duplicate_node_name_check = validator("pipeline", allow_reuse=True)( pipeline_config_duplicate_node_name_check ) @validator("metadata", allow_reuse=True) def _version_check(cls, v, values, **kwargs): # support dynamic subclassing pipeline_cls = getattr(cls, "__outer_class__", IRISPipeline) # assigned in BasePipeline.__init_subclass__ pipeline_metadata_version_check(cls, v.iris_version, values, expected_version=pipeline_cls.PACKAGE_VERSION) return v
__parameters_type__ = Parameters def __init__( self, config: Union[Dict[str, Any], Optional[str]] = None, env: Environment = Environment( pipeline_output_builder=build_simple_iris_pipeline_orb_output, error_manager=store_error_manager, call_trace_initialiser=PipelineCallTraceStorage.initialise, ), ) -> None: """ Initialize IRISPipeline with config and environment. Args: config (Union[Dict[str, Any], Optional[str]]): Pipeline config dict or YAML string. env (Environment): Pipeline environment. """ deserialized_config = self.load_config(config) if isinstance(config, str) or config is None else config super().__init__(deserialized_config, env)
[docs] def estimate(self, img_data: np.ndarray, eye_side: Literal["left", "right"], *args, **kwargs) -> Any: """Wrap the `run` method to match the Orb system AI models call interface. Args: img_data (np.ndarray): Input image data. eye_side (Literal["left", "right"]): Eye side. *args: Optional positional arguments for extensibility. **kwargs: Optional keyword arguments for extensibility. Returns: Any: Output created by builder specified in environment.pipeline_output_builder. """ return self.run(img_data, eye_side, *args, **kwargs)
[docs] def run(self, img_data: np.ndarray, eye_side: Literal["left", "right"], *args, **kwargs) -> Any: """ Wrap the `run` method to match the Orb system AI models call interface. Args: img_data (np.ndarray): Input Infrared image data. eye_side (Literal["left", "right"]): Eye side. *args: Optional positional arguments for extensibility. **kwargs: Optional keyword arguments for extensibility. Returns: Any: Output created by builder specified in environment.pipeline_output_builder. """ pipeline_input = {"img_data": img_data, "eye_side": eye_side} return super().run(pipeline_input, *args, **kwargs)
def _handle_input(self, pipeline_input: Any, *args, **kwargs) -> None: """ Write the IR image input to the call trace. Args: pipeline_input (dict): Should contain 'img_data' and 'eye_side'. *args: Optional positional arguments for extensibility. **kwargs: Optional keyword arguments for extensibility. """ ir_image = IRImage(img_data=pipeline_input["img_data"], eye_side=pipeline_input["eye_side"]) self.call_trace.write_input(ir_image) def _handle_output(self, *args, **kwargs) -> Any: """ Build and return the pipeline output using the environment's output builder. Args: *args: Optional positional arguments for extensibility. **kwargs: Optional keyword arguments for extensibility. Returns: Any: Output as built by the pipeline_output_builder. """ return self.env.pipeline_output_builder(self.call_trace) def _handle_node_error(self, node: PipelineNode, error: Exception) -> bool: """ Custom error handling for node execution in IRISPipeline. Handles PipelineCallTraceStorageError and KeyError by checking if the node's class is in disabled_qa. Args: node (PipelineNode): The node where the error occurred. error (Exception): The exception raised during node execution. Returns: bool: True if the error was skipped, False otherwise. """ # If the error is a PipelineCallTraceStorageError or KeyError and the node is in disabled_qa, skip if isinstance(error, (PipelineCallTraceStorageError, KeyError)): node_class = pydoc.locate(node.algorithm.class_name) if node_class in getattr(self.env, "disabled_qa", []): return True # Otherwise, treat as a missing node error self.env.error_manager(self.call_trace, ValueError(f"Could not find node {node.name}.")) return False else: # For all other errors, use the default error manager self.env.error_manager(self.call_trace, error) return False
[docs] @classmethod def load_config(cls, config: Optional[str]) -> Dict[str, Any]: """ Load and deserialize the pipeline configuration. Args: config (Optional[str]): YAML string or None for default config. Returns: Dict[str, Any]: Deserialized config dict. """ if config is None or config == "": # noqa config = cls.DEFAULT_PIPELINE_CFG_PATH deserialized_config = load_yaml_config(config) return deserialized_config
[docs] @classmethod def load_from_config(cls, config: str) -> Dict[str, Union["IRISPipeline", Optional[Dict[str, Any]]]]: """ Given an iris config string in base64, initialise an IRISPipeline with this config. Args: config (str): Base64-encoded config string. Returns: Dict[str, Union[IRISPipeline, Optional[Dict[str, Any]]]]: Initialised pipeline and error (if any). """ error = None iris_pipeline = None try: decoded_config_str = base64_decode_str(config) iris_pipeline = cls(config=decoded_config_str) except Exception as exception: error = { "error_type": type(exception).__name__, "message": str(exception), "traceback": "".join(traceback.format_tb(exception.__traceback__)), } return {"agent": iris_pipeline, "error": error}
[docs] def update_config(self, config: str) -> None: """Update the pipeline configuration based on the provided base64-encoded string. Args: config (str): Base64-encoded string of the new configuration. """ if not re.fullmatch(r"[A-Za-z0-9+/]*={0,2}", config): raise ValueError("Invalid base64-encoded string") decoded_config_str = base64_decode_str(config) config_dict = self.load_config(decoded_config_str) params = self.__parameters_type__(**config_dict) self._check_pipeline_coherency(params) self.params = params self.nodes = self._instanciate_nodes() self.call_trace = self.env.call_trace_initialiser(nodes=self.nodes, pipeline_nodes=self.params.pipeline)