Source code for iris.io.validators

import re
from typing import Any, Callable, Dict, Iterable, List

import numpy as np
from pydantic import fields

from iris.io.errors import IRISPipelineError

# ----- validators -----


def is_odd(cls: type, v: int, field: fields.ModelField) -> int:
    """Validate that a value (e.g. a kernel size) is an odd number.

    Args:
        cls (type): Class type.
        v (int): Value to check.
        field (fields.ModelField): Field descriptor.

    Raises:
        ValueError: Raised when `v` is divisible by 2.

    Returns:
        int: `v`, unchanged, sent for further processing.
    """
    # Happy path first: anything with a non-zero remainder is accepted.
    if (v % 2) != 0:
        return v

    raise ValueError(f"{cls.__name__}: {field.name} must be odd numbers.")
def is_uint8(cls: type, v: np.ndarray, field: fields.ModelField) -> np.ndarray:
    """Validate that an array holds uint8 data.

    The array is accepted only if every value lies in [0, 255] and its dtype
    is exactly `np.uint8`.

    Args:
        cls (type): Class type.
        v (np.ndarray): Value to check.
        field (fields.ModelField): Field descriptor.

    Raises:
        ValueError: Raised when a value is outside [0, 255] or the dtype is not uint8.

    Returns:
        np.ndarray: `v`, unchanged, sent for further processing.
    """
    # The range is checked before the dtype, and the upper bound is only
    # evaluated when the lower bound holds.
    within_range = np.all(v >= 0) and np.all(v <= 255)
    if within_range and v.dtype == np.uint8:
        return v

    raise ValueError(f"{cls.__name__}: {field.name} must be of uint8 type. Received {v.dtype}")
def is_binary(cls: type, v: np.ndarray, field: fields.ModelField) -> np.ndarray:
    """Validate that an array is binary, i.e. has a boolean dtype.

    Args:
        cls (type): Class type.
        v (np.ndarray): Value to check.
        field (fields.ModelField): Field descriptor.

    Raises:
        ValueError: Raised when the array dtype is not bool.

    Returns:
        np.ndarray: `v`, unchanged, sent for further processing.
    """
    if v.dtype == np.dtype("bool"):
        return v

    raise ValueError(f"{cls.__name__}: {field.name} must be binary. got dtype {v.dtype}")
def is_list_of_points(cls: type, v: np.ndarray, field: fields.ModelField) -> np.ndarray:
    """Validate that an array is a list of 2D points, i.e. has shape (_, 2).

    Args:
        cls (type): Class type.
        v (np.ndarray): Value to check.
        field (fields.ModelField): Field descriptor.

    Raises:
        ValueError: Raised when the array is not two-dimensional or its second axis is not of size 2.

    Returns:
        np.ndarray: `v`, unchanged, sent for further processing.
    """
    shape = v.shape
    # The second axis is only inspected once the array is known to have two axes.
    has_points_shape = len(shape) == 2 and shape[1] == 2
    if not has_points_shape:
        raise ValueError(f"{cls.__name__}: {field.name} must have shape (_, 2).")

    return v
def is_not_empty(cls: type, v: List[Any], field: fields.ModelField) -> List[Any]:
    """Validate that a list contains at least one element.

    Args:
        cls (type): Class type.
        v (List[Any]): Value to check.
        field (fields.ModelField): Field descriptor.

    Raises:
        ValueError: Raised when `v` has length zero.

    Returns:
        List[Any]: `v`, unchanged, sent for further processing.
    """
    # `len` is used rather than truthiness so that array-like values work too.
    if len(v) > 0:
        return v

    raise ValueError(f"{cls.__name__}: {field.name} list cannot be empty.")
def is_not_zero_sum(cls: type, v: Any, field: fields.ModelField) -> Any:
    """Validate that the elements of a value do not sum to zero.

    Args:
        cls (type): Class type.
        v (Any): Value to check; it is summed with `np.sum`.
        field (fields.ModelField): Field descriptor.

    Raises:
        ValueError: Raised when `np.sum(v)` equals 0.

    Returns:
        Any: `v`, unchanged, sent for further processing.
    """
    total = np.sum(v)
    if total != 0:
        return v

    raise ValueError(f"{cls.__name__}: {field.name} sum cannot be zero.")
def are_all_positive(cls: type, v: Any, field: fields.ModelField) -> Any:
    """Validate that a value, or every element of an iterable, is not negative.

    Note that zero is accepted: elements are compared with `>= 0` and scalars
    are rejected only when `< 0.0`.

    Args:
        cls (type): Class type.
        v (Any): Scalar or iterable of values to check.
        field (fields.ModelField): Field descriptor.

    Raises:
        ValueError: Raised when the scalar, or any element of the iterable, is negative.

    Returns:
        Any: `v`, unchanged, sent for further processing.
    """
    # Scalar case: a single comparison is enough.
    if not isinstance(v, Iterable):
        if v < 0.0:
            raise ValueError(f"{cls.__name__}: {field.name} must be positive. Received {v}")
        return v

    # Iterable case: each element is compared, then the results are reduced with numpy.
    non_negative_flags = np.array([value >= 0 for value in v])
    if non_negative_flags.all():
        return v

    raise ValueError(f"{cls.__name__}: all {field.name} must be positive. Received {v}")
def iris_code_version_check(cls: type, v: str, field: fields.ModelField) -> str:
    """Check that the iris code version string has the form ``v<major>.<minor>``.

    Only the format of the string is validated (e.g. ``"v0.1"``); the value is
    not compared against any installed package version here.

    Args:
        cls (type): Class type.
        v (str): Version string to check.
        field (fields.ModelField): Field descriptor.

    Raises:
        IRISPipelineError: Raised when `v` is not exactly "v", digits, a dot, digits.

    Returns:
        str: `v`, unchanged, sent for further processing.
    """
    # `fullmatch` anchors both ends of the string. The previous `re.match` with a
    # trailing `$` also accepted a version followed by a newline ("v1.0\n"),
    # because `$` matches just before a final newline.
    if re.fullmatch(r"v[\d]+\.[\d]+", v) is None:
        raise IRISPipelineError(f"Wrong iris code version. Expected standard version nuber, received {v}")

    return v
def to_dtype_float32(cls: type, v: np.ndarray, field: fields.ModelField) -> np.ndarray:
    """Cast the input array to dtype `np.float32`.

    Args:
        cls (type): Class type.
        v (np.ndarray): Value to convert.
        field (fields.ModelField): Field descriptor.

    Returns:
        np.ndarray: Result of `v.astype(np.float32)`, sent for further processing.
    """
    as_float32 = v.astype(np.float32)
    return as_float32
# ----- root_validators -----
def is_valid_bbox(cls: type, values: Dict[str, float]) -> Dict[str, float]:
    """Validate that a bounding box has strictly positive width and height.

    Args:
        cls (type): Class type.
        values (Dict[str, float]): Mapping holding the "x_min", "x_max", "y_min" and "y_max" entries.

    Raises:
        ValueError: Raised when x_min >= x_max or y_min >= y_max.

    Returns:
        Dict[str, float]: `values`, unchanged, sent for further processing.
    """
    # The y bounds are only compared when the x bounds are valid.
    x_is_degenerate = values["x_min"] >= values["x_max"]
    if not x_is_degenerate and not (values["y_min"] >= values["y_max"]):
        return values

    raise ValueError(
        f'{cls.__name__}: invalid bbox. x_min={values["x_min"]}, x_max={values["x_max"]},'
        f' y_min={values["y_min"]}, y_max={values["y_max"]}'
    )
# ----- parametrized validators -----
def is_array_n_dimensions(nb_dimensions: int) -> Callable:
    """Build a pydantic validator checking that an array is n-dimensional.

    Args:
        nb_dimensions (int): Number of dimensions the array must have.

    Returns:
        Callable: The validator.
    """

    def validator(cls: type, v: np.ndarray, field: fields.ModelField) -> np.ndarray:
        """Check that the array has the expected number of dimensions.

        An array of shape (0,) is additionally accepted when zero dimensions are expected.
        """
        shape = v.shape
        if len(shape) == nb_dimensions:
            return v

        # Special case: an empty 1D array passes when 0 dimensions are requested.
        if nb_dimensions == 0 and shape == (0,):
            return v

        raise ValueError(
            f"{cls.__name__}: wrong number of dimensions for {field.name}. "
            f"Expected {nb_dimensions}, got {len(v.shape)}"
        )

    return validator
# ----- parametrized root_validators -----
def are_lengths_equal(field1: str, field2: str) -> Callable:
    """Build a pydantic root validator checking that two fields have the same length.

    Args:
        field1 (str): Name of the first field.
        field2 (str): Name of the second field.

    Returns:
        Callable: The validator.
    """

    def __root_validator(cls: type, values: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        """Check that len(field1) equals len(field2)."""
        lengths_match = len(values[field1]) == len(values[field2])
        if lengths_match:
            return values

        raise ValueError(
            f"{cls.__name__}: {field1} and {field2} length mismatch, "
            f"resp. {len(values[field1])} and {len(values[field2])}"
        )

    return __root_validator
def are_shapes_equal(field1: str, field2: str) -> Callable:
    """Build a pydantic root validator checking that two array fields have the same shape.

    Args:
        field1 (str): Name of the first field.
        field2 (str): Name of the second field.

    Returns:
        Callable: The validator.
    """

    def __root_validator(cls: type, values: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Check that field1.shape equals field2.shape."""
        if values[field1].shape == values[field2].shape:
            return values

        raise ValueError(f"{cls.__name__}: {field1} and {field2} shape mismatch.")

    return __root_validator
def are_all_shapes_equal(field1: str, field2: str) -> Callable:
    """Build a pydantic root validator comparing two lists of arrays element by element.

    The resulting validator checks that both lists have the same length and
    that arrays at the same position have the same shape.

    Args:
        field1 (str): Name of the first field.
        field2 (str): Name of the second field.

    Returns:
        Callable: The validator.
    """

    def __root_validator(cls: type, values: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Check that len(field1) equals len(field2) and that paired elements share a shape."""
        shapes_field_1 = [array.shape for array in values[field1]]
        shapes_field_2 = [array.shape for array in values[field2]]

        # The shape lists are only compared when the lengths already agree.
        lengths_match = len(values[field1]) == len(values[field2])
        if lengths_match and shapes_field_1 == shapes_field_2:
            return values

        raise ValueError(
            f"{cls.__name__}: {field1} and {field2} shape mismatch, resp. {shapes_field_1} and {shapes_field_2}."
        )

    return __root_validator