from __future__ import annotations

import os
from typing import Dict, List, Literal, Optional, Tuple

import numpy as np
import onnx
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from pydantic import PositiveInt

from iris.callbacks.callback_interface import Callback
from iris.io.dataclasses import IRImage, SegmentationMap
from iris.nodes.segmentation.multilabel_segmentation_interface import MultilabelSemanticSegmentationInterface
class ONNXMultilabelSegmentation(MultilabelSemanticSegmentationInterface):
    """Implementation of class which uses ONNX model to perform semantic segmentation maps prediction.

    For more detailed model description check model card available in SEMSEG_MODEL_CARD.md file.
    """

    class Parameters(MultilabelSemanticSegmentationInterface.Parameters):
        """Parameters class for ONNXMultilabelSegmentation objects."""

        # ONNX Runtime session wrapping the loaded model.
        session: ort.InferenceSession
        # Neural Network input resolution, interpreted as (width, height) by `preprocess`.
        input_resolution: Tuple[PositiveInt, PositiveInt]
        # Number of input channels the network expects: 1 (grayscale) or 3 (RGB).
        input_num_channels: Literal[1, 3]

    __parameters_type__ = Parameters

    @classmethod
    def create_from_hugging_face(
        cls,
        model_name: str = "iris_semseg_upp_scse_mobilenetv2.onnx",
        input_resolution: Tuple[PositiveInt, PositiveInt] = (640, 480),
        input_num_channels: Literal[1, 3] = 3,
        callbacks: Optional[List[Callback]] = None,
    ) -> ONNXMultilabelSegmentation:
        """Create ONNXMultilabelSegmentation object by downloading model from HuggingFace repository `MultilabelSemanticSegmentationInterface.HUGGING_FACE_REPO_ID`.

        Args:
            model_name (str, optional): Name of the ONNX model stored in HuggingFace repo. Defaults to "iris_semseg_upp_scse_mobilenetv2.onnx".
            input_resolution (Tuple[PositiveInt, PositiveInt], optional): Neural Network input image resolution. Defaults to (640, 480).
            input_num_channels (Literal[1, 3], optional): Neural Network input image number of channels. Defaults to 3.
            callbacks (Optional[List[Callback]], optional): List of algorithm callbacks. Defaults to None (no callbacks).

        Returns:
            ONNXMultilabelSegmentation: ONNXMultilabelSegmentation object.
        """
        # Disable hf_transfer so the download does not require the optional `hf_transfer` package.
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

        model_path = hf_hub_download(
            repo_id=MultilabelSemanticSegmentationInterface.HUGGING_FACE_REPO_ID,
            cache_dir=MultilabelSemanticSegmentationInterface.MODEL_CACHE_DIR,
            filename=model_name,
        )

        # Instantiate via `cls` so subclasses reuse this factory without overriding it.
        return cls(model_path, input_resolution, input_num_channels, callbacks)

    def __init__(
        self,
        model_path: str,
        input_resolution: Tuple[PositiveInt, PositiveInt] = (640, 480),
        input_num_channels: Literal[1, 3] = 3,
        callbacks: Optional[List[Callback]] = None,
    ) -> None:
        """Assign parameters.

        Args:
            model_path (str): Path to the ONNX model.
            input_resolution (Tuple[PositiveInt, PositiveInt], optional): Neural Network input image resolution. Defaults to (640, 480).
            input_num_channels (Literal[1, 3], optional): Neural Network input image number of channels. Defaults to 3.
            callbacks (Optional[List[Callback]], optional): List of algorithm callbacks. Defaults to None (no callbacks).
        """
        # Validate the model graph before building a session so that a corrupted
        # download fails early with a clear checker error.
        onnx_model = onnx.load(model_path)
        onnx.checker.check_model(onnx_model)

        super().__init__(
            session=ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]),
            input_resolution=input_resolution,
            input_num_channels=input_num_channels,
            # `None` sentinel instead of a mutable `[]` default; a fresh list is built per call.
            callbacks=callbacks if callbacks is not None else [],
        )

    def run(self, image: IRImage) -> SegmentationMap:
        """Perform semantic segmentation prediction on an image.

        Args:
            image (IRImage): Infrared image object.

        Returns:
            SegmentationMap: Postprocessed model predictions.
        """
        nn_input = self._preprocess(image.img_data)
        prediction = self._forward(nn_input)
        # Resize predictions back to the original image resolution (width, height).
        return self._postprocess(prediction, original_image_resolution=(image.width, image.height))

    def _preprocess(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """Preprocess image so that inference with ONNX model is possible.

        Args:
            image (np.ndarray): Infrared image object.

        Returns:
            Dict[str, np.ndarray]: Dictionary with wrapped input name and image data {input_name: image_data}.
        """
        # Copy first so the shared preprocessing never mutates the caller's array.
        nn_input = image.copy()
        nn_input = self.preprocess(nn_input, self.params.input_resolution, self.params.input_num_channels)

        # ONNX Runtime expects {input_name: tensor}; the model exposes a single input.
        return {self.params.session.get_inputs()[0].name: nn_input.astype(np.float32)}

    def _forward(self, preprocessed_input: Dict[str, np.ndarray]) -> List[np.ndarray]:
        """Neural Network forward pass.

        Args:
            preprocessed_input (Dict[str, np.ndarray]): Inputs.

        Returns:
            List[np.ndarray]: Predictions.
        """
        # `None` output_names requests every declared model output.
        return self.params.session.run(None, preprocessed_input)

    def _postprocess(self, nn_output: List[np.ndarray], original_image_resolution: Tuple[int, int]) -> SegmentationMap:
        """Postprocess model prediction and wrap it within SegmentationMap object for further processing.

        Args:
            nn_output (List[np.ndarray]): Neural Network output. Should be of length equal to 2.
            original_image_resolution (Tuple[int, int]): Original image resolution used to resize predicted semantic segmentation maps.

        Returns:
            SegmentationMap: Postprocessed model predictions.
        """
        # Only the first output tensor carries the segmentation maps.
        segmaps_tensor = nn_output[0]
        segmaps_tensor = self.postprocess_segmap(segmaps_tensor, original_image_resolution)

        segmap = SegmentationMap(
            predictions=segmaps_tensor, index2class=MultilabelSemanticSegmentationInterface.CLASSES_MAPPING
        )

        return segmap