Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 15 additions & 49 deletions src/model_api/models/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from model_api.models.result import ClassificationResult, Label

from .model import Model
from .types import BooleanValue, ListValue, NumericalValue, StringValue
from .parameters import ParameterRegistry
from .utils import load_labels

if TYPE_CHECKING:
Expand Down Expand Up @@ -65,26 +65,19 @@ def __init__(
self.image_blob_names = self._get_inputs()
self.image_blob_name = self.image_blob_names[0]
self.nscthw_layout = "NSCTHW" in self.inputs[self.image_blob_name].layout
self.labels: list[str]
self.path_to_labels: str
self.mean_values: list[int | float]
self.pad_value: int
self.resize_type: str
self.reverse_input_channels: bool
self.scale_values: list[int | float]

if self.nscthw_layout:
self.n, self.s, self.c, self.t, self.h, self.w = self.inputs[self.image_blob_name].shape
else:
self.n, self.s, self.t, self.h, self.w, self.c = self.inputs[self.image_blob_name].shape
self.resize = RESIZE_TYPES[self.resize_type]
self.resize = RESIZE_TYPES[self.params.resize_type]
self.input_transform = InputTransform(
self.reverse_input_channels,
self.mean_values,
self.scale_values,
self.params.reverse_input_channels,
self.params.mean_values,
self.params.scale_values,
)
if self.path_to_labels:
self.labels = load_labels(self.path_to_labels)
if self.params.path_to_labels:
self._labels = load_labels(self.params.path_to_labels)

@property
def clip_size(self) -> int:
Expand All @@ -94,39 +87,11 @@ def clip_size(self) -> int:
def parameters(cls) -> dict[str, Any]:
parameters = super().parameters()
parameters.update(
{
"labels": ListValue(description="List of class labels"),
"path_to_labels": StringValue(
description="Path to file with labels. Overrides the labels, if they sets via 'labels' parameter",
),
"mean_values": ListValue(
description=(
"Normalization values, which will be subtracted from image channels "
"for image-input layer during preprocessing"
),
default_value=[],
),
"pad_value": NumericalValue(
int,
min=0,
max=255,
description="Pad value for resize_image_letterbox embedded into a model",
default_value=0,
),
"resize_type": StringValue(
default_value="standard",
choices=tuple(RESIZE_TYPES.keys()),
description="Type of input image resizing",
),
"reverse_input_channels": BooleanValue(
default_value=False,
description="Reverse the input channel order",
),
"scale_values": ListValue(
default_value=[],
description="Normalization values, which will divide the image channels for image-input layer",
),
},
ParameterRegistry.merge(
ParameterRegistry.LABELS,
ParameterRegistry.IMAGE_RESIZE,
ParameterRegistry.IMAGE_PREPROCESSING,
),
)
return parameters

Expand Down Expand Up @@ -193,7 +158,7 @@ def preprocess(
"original_shape": inputs.shape,
"resized_shape": (self.n, self.s, self.c, self.t, self.h, self.w),
}
resized_inputs = [self.resize(frame, (self.w, self.h), pad_value=self.pad_value) for frame in inputs]
resized_inputs = [self.resize(frame, (self.w, self.h), pad_value=self.params.pad_value) for frame in inputs]
np_frames = self._change_layout(
[self.input_transform(inputs) for inputs in resized_inputs],
)
Expand Down Expand Up @@ -222,8 +187,9 @@ def postprocess(
"""Post-process."""
logits = next(iter(outputs.values())).squeeze()
index = np.argmax(logits)
labels = self.params.labels
return ClassificationResult(
[Label(int(index), self.labels[index], logits[index])],
[Label(int(index), labels[index], logits[index])],
np.ndarray(0),
np.ndarray(0),
np.ndarray(0),
Expand Down
48 changes: 12 additions & 36 deletions src/model_api/models/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import numpy as np

from model_api.models.image_model import ImageModel
from model_api.models.parameters import ParameterRegistry
from model_api.models.result import AnomalyResult
from model_api.models.types import ListValue, NumericalValue, StringValue

if TYPE_CHECKING:
from model_api.adapters.inference_adapter import InferenceAdapter
Expand Down Expand Up @@ -67,11 +67,6 @@ def __init__(
) -> None:
super().__init__(inference_adapter, configuration, preload)
self._check_io_number(1, (1, 4))
self.normalization_scale: float
self.image_threshold: float
self.pixel_threshold: float
self.task: str
self.labels: list[str]

def preprocess(self, inputs: np.ndarray) -> list[dict]:
"""Data preprocess method for Anomalib models.
Expand Down Expand Up @@ -103,7 +98,7 @@ def preprocess(self, inputs: np.ndarray) -> list[dict]:
else:
resized_shape = (self.w, self.h, self.c)
# For fixed models, use standard preprocessing
if self.embedded_processing:
if self.params.embedded_processing:
processed_image = inputs[None]
else:
# Resize image to expected model input dimensions
Expand Down Expand Up @@ -148,16 +143,17 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> A
anomaly_map = predictions.squeeze()
npred_score = anomaly_map.reshape(-1).max()

pred_label = self.labels[1] if npred_score > self.image_threshold else self.labels[0]
labels_list = self.params.labels
pred_label = labels_list[1] if npred_score > self.params.image_threshold else labels_list[0]

assert anomaly_map is not None
pred_mask = (anomaly_map >= self.pixel_threshold).astype(np.uint8)
anomaly_map = self._normalize(anomaly_map, self.pixel_threshold)
pred_mask = (anomaly_map >= self.params.pixel_threshold).astype(np.uint8)
anomaly_map = self._normalize(anomaly_map, self.params.pixel_threshold)

# normalize
npred_score = self._normalize(npred_score, self.image_threshold)
npred_score = self._normalize(npred_score, self.params.image_threshold)

if pred_label == self.labels[0]: # normal
if pred_label == labels_list[0]: # normal
npred_score = 1 - npred_score # Score of normal is 1 - score of anomaly
pred_score = npred_score.item()
else:
Expand All @@ -180,7 +176,7 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> A
(meta["original_shape"][1], meta["original_shape"][0]),
)

if self.task == "detection":
if self.params.task == "detection":
pred_boxes = self._get_boxes(pred_mask)

return AnomalyResult(
Expand All @@ -194,33 +190,13 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> A
@classmethod
def parameters(cls) -> dict:
parameters = super().parameters()
parameters.update(
{
"image_threshold": NumericalValue(
description="Image threshold",
min=0.0,
default_value=0.5,
),
"pixel_threshold": NumericalValue(
description="Pixel threshold",
min=0.0,
default_value=0.5,
),
"normalization_scale": NumericalValue(
description="Value used for normalization",
),
"task": StringValue(
description="Task type",
default_value="segmentation",
),
"labels": ListValue(description="List of class labels", value_type=str),
},
)
parameters.update(ParameterRegistry.ANOMALY)
parameters.update(ParameterRegistry.LABELS)
return parameters

def _normalize(self, tensor: np.ndarray, threshold: float) -> np.ndarray:
"""Currently supports only min-max normalization."""
normalized = ((tensor - threshold) / self.normalization_scale) + 0.5
normalized = ((tensor - threshold) / self.params.normalization_scale) + 0.5
return np.clip(normalized, 0, 1)

@staticmethod
Expand Down
Loading
Loading