[docs]classMultilabelSemanticSegmentationInterface(Algorithm):"""Interface of a model semantic segmentation prediction trained with multilabel labels."""HUGGING_FACE_REPO_ID="Worldcoin/iris-semantic-segmentation"MODEL_CACHE_DIR=os.path.join(os.path.dirname(os.path.abspath(__file__)),"assets")CLASSES_MAPPING={0:"eyeball",1:"iris",2:"pupil",3:"eyelashes",}
[docs]@classmethoddefcreate_from_hugging_face(cls)->MultilabelSemanticSegmentationInterface:"""Abstract function just to make sure all subclasses implement it. Raises: RuntimeError: Raised if subclass doesn't implement that class method. Returns: MultilabelSemanticSegmentationInterface: MultilabelSemanticSegmentationInterface subclass object. """raiseRuntimeError(f"`create_from_hugging_face` function hasn't been implemented for {cls.__name__} subclass.")
[docs]defpreprocess(self,image:np.ndarray,input_resolution:Tuple[int,int],nn_input_channels:int)->np.ndarray:"""Preprocess image before running a model inference. Args: image (np.ndarray): Image to preprocess. input_resolution (Tuple[int, int]): A model input resolution. nn_input_channels (int): A model input channels. Returns: np.ndarray: Preprocessed image. """nn_input=cv2.resize(image.astype(float),input_resolution)nn_input=np.divide(nn_input,255)# Replicates torchvision's ToTensornn_input=np.expand_dims(nn_input,axis=-1)nn_input=np.tile(nn_input,(1,1,nn_input_channels))# Replicates torchvision's Normalizationmeans=np.array([0.485,0.456,0.406])ifnn_input_channels==3else0.5stds=np.array([0.229,0.224,0.225])ifnn_input_channels==3else0.5nn_input-=meansnn_input/=stdsnn_input=nn_input.transpose(2,0,1)nn_input=np.expand_dims(nn_input,axis=0)returnnn_input