Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion edge_orchestrator/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ ARG BUILDOS
RUN if [ "$TARGETPLATFORM" = "linux/arm64" ] && [ "$BUILDOS" = "linux" ]; then \
apt update && apt install -y --no-install-recommends gnupg; \
echo "deb http://archive.raspberrypi.org/debian/ bookworm main" > /etc/apt/sources.list.d/raspi.list \
&& apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 82B129927FA3303E; \
&& gpg --keyserver keyserver.ubuntu.com --recv-keys 82B129927FA3303E \
&& gpg --export 82B129927FA3303E | tee /etc/apt/trusted.gpg.d/raspberry.gpg > /dev/null; \
fi

RUN if [ "$TARGETPLATFORM" = "linux/arm64" ] && [ "$BUILDOS" = "linux" ]; then \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@


class ImageExtension(str, Enum):
bmp = "bmp"
jpeg = "jpeg"
jpg = "jpg"
png = "png"
tiff = "tiff"
BMP = "bmp"
JPEG = "jpeg"
JPG = "jpg"
PNG = "png"
TIFF = "tiff"
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@


class ClassifPrediction(BaseModel):
prediction_type: Literal[PredictionType.class_]
prediction_type: Literal[PredictionType.CLASS_]
label: Optional[str] = None
probability: Annotated[Optional[float], AfterValidator(round_float)] = None
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@


class DetectionPrediction(BaseModel):
prediction_type: Literal[PredictionType.objects]
prediction_type: Literal[PredictionType.OBJECTS]
detected_objects: Dict[str, DetectedObject] = Field(default=dict())
label: Optional[Decision] = None
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def check_class_names_path(self):

@model_validator(mode="after")
def check_class_names_or_class_names_path(self):
if self.model_type in [ModelType.classification, ModelType.object_detection] and (
if self.model_type in [ModelType.CLASSIFICATION, ModelType.OBJECT_DETECTION] and (
(not self.class_names and not self.class_names_filepath) or (self.class_names and self.class_names_filepath)
):
raise ValueError("Either class_names or class_names_path is required (exclusive)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@


class ModelName(str, Enum):
fake_model = "fake_model"
marker_quality_control = "marker_quality_control"
pin_detection = "pin_detection"
mobilenet_ssd_v2_coco = "mobilenet_ssd_v2_coco"
mobilenet_ssd_v2_face = "mobilenet_ssd_v2_face"
yolo_coco_nano = "yolo_coco_nano"
FAKE_MODEL = "fake_model"
MARKER_QUALITY_CONTROL = "marker_quality_control"
PIN_DETECTION = "pin_detection"
MOBILENET_SSD_V2_COCO = "mobilenet_ssd_v2_coco"
MOBILENET_SSD_V2_FACE = "mobilenet_ssd_v2_face"
YOLO_COCO_NANO = "yolo_coco_nano"
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@


class ModelType(str, Enum):
classification = "classification"
object_detection = "object_detection"
segmentation = "segmentation"
CLASSIFICATION = "classification"
OBJECT_DETECTION = "object_detection"
SEGMENTATION = "segmentation"
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class PredictionType(str, Enum):
class_ = "class"
objects = "objects"
probability = "probability"
mask = "mask"
CLASS_ = "class"
OBJECTS = "objects"
PROBABILITY = "probability"
MASK = "mask"
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def __init__(self, camera_rule_config: CameraRuleConfig):

def _get_camera_decision(self, prediction: Prediction) -> Decision:
classif = prediction
if classif.prediction_type != PredictionType.class_:
if classif.prediction_type != PredictionType.CLASS_:
self._logger.warning(
f"You can not use an ExpectedLabelRule on something other than {PredictionType.class_.value}, "
f"You can not use an ExpectedLabelRule on something other than {PredictionType.CLASS_.value}, "
"no decision returned."
)
return Decision.NO_DECISION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ def __init__(self, camera_rule_config: CameraRuleConfig):

def _get_camera_decision(self, prediction: Prediction) -> Decision:
detec_predict_with_classif = prediction
if detec_predict_with_classif.prediction_type != PredictionType.objects:
if detec_predict_with_classif.prediction_type != PredictionType.OBJECTS:
self._logger.warning(
"You can not use an MaxNbObjectsRule on something other than "
f"{PredictionType.objects.value}, no decision returned."
f"{PredictionType.OBJECTS.value}, no decision returned."
)
return Decision.NO_DECISION

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ def __init__(self, camera_rule_config: CameraRuleConfig):

def _get_camera_decision(self, prediction: Prediction) -> Decision:
detec_predict_with_classif = prediction
if detec_predict_with_classif.prediction_type != PredictionType.objects:
if detec_predict_with_classif.prediction_type != PredictionType.OBJECTS:
self._logger.warning(
"You can not use an MinNbObjectsRule on something other than "
f"{PredictionType.objects.value}, no decision returned."
f"You can not use a MinNbObjectsRule on something other than "
f"{PredictionType.OBJECTS.value}, no decision returned."
)
return Decision.NO_DECISION

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ def __init__(self, camera_rule_config: CameraRuleConfig):

def _get_camera_decision(self, prediction: Prediction) -> Decision:
classif = prediction
if classif.prediction_type != PredictionType.class_:
if classif.prediction_type != PredictionType.CLASS_:
self._logger.warning(
"You can not use an ExpectedLabelRule on something other than "
f"{PredictionType.class_.value}, no decision returned."
f"{PredictionType.CLASS_.value}, no decision returned."
)
return Decision.NO_DECISION
if classif.label is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def _predict(self, preprocessed_binary: np.ndarray) -> Dict[str, Any]:
def _post_process_prediction(self, prediction_response: Dict[str, Any]) -> Prediction:
if len(prediction_response["outputs"]) == 0:
self._logger.warning("No predictions found")
return ClassifPrediction(prediction_type=PredictionType.class_)
return ClassifPrediction(prediction_type=PredictionType.CLASS_)

predictions = prediction_response["outputs"][0]
number_predictions_classes = len(predictions)
Expand All @@ -54,7 +54,7 @@ def _post_process_prediction(self, prediction_response: Dict[str, Any]) -> Predi
"the number of predictions ({number_predictions_classes})"
)
return ClassifPrediction(
prediction_type=PredictionType.class_,
prediction_type=PredictionType.CLASS_,
label=class_names[np.argmax(predictions)],
probability=float(np.max(predictions)),
)
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def _pre_process_binary(self, binary: bytes) -> np.ndarray:

async def _predict(self, preprocessed_binary: np.ndarray) -> Dict[str, Any]:
model_type = self._model_forwarder_config.model_type
if model_type == ModelType.classification:
if model_type == ModelType.CLASSIFICATION:
return {"label": random.choice(["OK", "KO"]), "probability": random.uniform(0, 1)}
elif model_type == ModelType.object_detection:
elif model_type == ModelType.OBJECT_DETECTION:
return {
"detected_objects": {
"object_1": {
Expand All @@ -56,25 +56,25 @@ async def _predict(self, preprocessed_binary: np.ndarray) -> Dict[str, Any]:

def _post_process_prediction(self, prediction_response: Dict[str, Any]) -> Prediction:
model_type = self._model_forwarder_config.model_type
if model_type == ModelType.classification:
if model_type == ModelType.CLASSIFICATION:
return ClassifPrediction(
prediction_type=PredictionType.class_,
prediction_type=PredictionType.CLASS_,
label=prediction_response["label"],
probability=prediction_response["probability"],
)
elif model_type == ModelType.object_detection:
elif model_type == ModelType.OBJECT_DETECTION:
detected_objects = prediction_response["detected_objects"]
return DetectionPrediction(
prediction_type=PredictionType.objects,
prediction_type=PredictionType.OBJECTS,
detected_objects={
"object_1": DetectedObject(
prediction_type=ModelType.classification,
prediction_type=ModelType.CLASSIFICATION,
label=detected_objects["object_1"]["label"],
location=detected_objects["object_1"]["location"],
objectness=detected_objects["object_1"]["objectness"],
),
"object_2": DetectedObject(
prediction_type=ModelType.classification,
prediction_type=ModelType.CLASSIFICATION,
label=detected_objects["object_2"]["label"],
location=detected_objects["object_2"]["location"],
objectness=detected_objects["object_2"]["objectness"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,27 @@ def __init__(self):
self._logger = logging.getLogger(__name__)

def create_model_forwarder(self, model_forwarder_config: ModelForwarderConfig) -> IModelForwarder:
if model_forwarder_config.model_name == ModelName.fake_model:
if model_forwarder_config.model_name == ModelName.FAKE_MODEL:
from edge_orchestrator.infrastructure.adapters.model_forwarder.fake_model_forwarder import (
FakeModelForwarder,
)

return FakeModelForwarder(model_forwarder_config)
elif model_forwarder_config.model_type == ModelType.classification:
elif model_forwarder_config.model_type == ModelType.CLASSIFICATION:
from edge_orchestrator.infrastructure.adapters.model_forwarder.classif_model_forwarder import (
ClassifModelForwarder,
)

return ClassifModelForwarder(model_forwarder_config)

elif model_forwarder_config.model_type == ModelType.object_detection:
elif model_forwarder_config.model_type == ModelType.OBJECT_DETECTION:
from edge_orchestrator.infrastructure.adapters.model_forwarder.object_detection_model_forwarder import (
ObjectDetectionModelForwarder,
)

return ObjectDetectionModelForwarder(model_forwarder_config)

elif model_forwarder_config.model_type == ModelType.segmentation:
elif model_forwarder_config.model_type == ModelType.SEGMENTATION:
from edge_orchestrator.infrastructure.adapters.model_forwarder.segmentation_model_forwarder import (
SegmentationModelForwarder,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _pre_process_binary(self, binary: bytes) -> np.ndarray:
async def _predict(self, preprocessed_binary: np.ndarray) -> Dict[str, Any]:
# TODO: refactor edge_model_serving to remove model_type from the request
model_type = None
if self._model_forwarder_config.model_name == ModelName.yolo_coco_nano:
if self._model_forwarder_config.model_name == ModelName.YOLO_COCO_NANO:
model_type = "yolo"
async with aiohttp.ClientSession() as session:
async with session.post(
Expand All @@ -62,7 +62,7 @@ def _post_process_prediction(self, prediction_response: Dict[str, Any]) -> Predi
or len(predictions["detection_classes"]) == 0
):
self._logger.warning("No detected objects found!")
return DetectionPrediction(prediction_type=PredictionType.objects, detected_objects={})
return DetectionPrediction(prediction_type=PredictionType.OBJECTS, detected_objects={})

detected_objects = {}
boxes_coordinates, objectness_scores, detection_classes = (
Expand All @@ -80,4 +80,4 @@ def _post_process_prediction(self, prediction_response: Dict[str, Any]) -> Predi
objectness=box_objectness,
label=class_names[int(detection_classes[box_index])],
)
return DetectionPrediction(prediction_type=PredictionType.objects, detected_objects=detected_objects)
return DetectionPrediction(prediction_type=PredictionType.OBJECTS, detected_objects=detected_objects)
2 changes: 2 additions & 0 deletions edge_orchestrator/tests/helpers/tf_serving_container.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from typing import Dict

from testcontainers.core.container import DockerContainer
Expand Down Expand Up @@ -28,4 +29,5 @@ def _connect(self, default_starting_log: str):
def start(self, starting_log: str = r"Uvicorn running on"):
super().start()
self._connect(starting_log)
time.sleep(2) # wait for the container to be fully started
return self
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class TestClassifModelForwarder:
@pytest.mark.parametrize(
"model_name,probability",
[
(ModelName.marker_quality_control, 0.83054),
(ModelName.pin_detection, 0.99962),
(ModelName.MARKER_QUALITY_CONTROL, 0.83054),
(ModelName.PIN_DETECTION, 0.99962),
],
)
async def test_classif_model_forwarder_should_return_classif_prediction(
Expand All @@ -39,7 +39,7 @@ async def test_classif_model_forwarder_should_return_classif_prediction(
image_resolution = ImageResolution(width=224, height=224)
model_forward_config = ModelForwarderConfig(
model_name=model_name,
model_type=ModelType.classification,
model_type=ModelType.CLASSIFICATION,
model_version="1",
class_names=["OK", "KO"],
model_serving_url=setup_test_tflite_serving,
Expand All @@ -48,7 +48,7 @@ async def test_classif_model_forwarder_should_return_classif_prediction(
model_fowarder = ClassifModelForwarder(model_forward_config)

expected_prediction = ClassifPrediction(
prediction_type=PredictionType.class_, label="KO", probability=probability
prediction_type=PredictionType.CLASS_, label="KO", probability=probability
)

# When
Expand All @@ -67,8 +67,8 @@ async def test_classif_model_forwarder_should_raise_exception_with_bad_url_provi
# Given
image_resolution = ImageResolution(width=224, height=224)
model_forward_config = ModelForwarderConfig(
model_name=ModelName.marker_quality_control,
model_type=ModelType.classification,
model_name=ModelName.MARKER_QUALITY_CONTROL,
model_type=ModelType.CLASSIFICATION,
model_version="1",
class_names=["OK", "KO"],
model_serving_url=setup_test_tflite_serving,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ async def test_model_forwarder_manager_should_return_no_prediction_with_bad_url(
camera_type=CameraType.FAKE,
source_directory="fake",
model_forwarder_config=ModelForwarderConfig(
model_name=ModelName.marker_quality_control,
model_type=ModelType.classification,
model_name=ModelName.MARKER_QUALITY_CONTROL,
model_type=ModelType.CLASSIFICATION,
model_serving_url="http://bad_url",
model_version="1",
class_names=["OK", "KO"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ class TestObjectDetectionModelForwarder:
"model_name,image_resolution,class_names,expected_number_of_objects",
[
(
ModelName.mobilenet_ssd_v2_coco,
ModelName.MOBILENET_SSD_V2_COCO,
{"width": 300, "height": 300},
["person", "bicycle"] * 50,
20,
),
(
ModelName.mobilenet_ssd_v2_face,
ModelName.MOBILENET_SSD_V2_FACE,
{"width": 320, "height": 320},
["person", "bicycle"] * 50,
50,
),
(
ModelName.yolo_coco_nano,
ModelName.YOLO_COCO_NANO,
{"width": 320, "height": 320},
["person", "bicycle"] * 50,
5,
Expand All @@ -56,7 +56,7 @@ async def test_object_detection_model_forwarder_should_return_detection_predicti
image_resolution = ImageResolution(**image_resolution)
model_forward_config = ModelForwarderConfig(
model_name=model_name,
model_type=ModelType.object_detection,
model_type=ModelType.OBJECT_DETECTION,
model_version="1",
class_names=class_names,
model_serving_url=setup_test_tflite_serving,
Expand Down Expand Up @@ -87,8 +87,8 @@ async def test_object_detection_model_forwarder_should_raise_exception_with_bad_
# Given
image_resolution = ImageResolution(width=224, height=224)
model_forward_config = ModelForwarderConfig(
model_name=ModelName.mobilenet_ssd_v2_coco,
model_type=ModelType.object_detection,
model_name=ModelName.MOBILENET_SSD_V2_COCO,
model_type=ModelType.OBJECT_DETECTION,
model_version="1",
class_names=["person", "bicycle"],
model_serving_url=setup_test_tflite_serving,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def test_model_fowarder_config_should_raise_exception_of_class_names_path_does_n
# When
with pytest.raises(ValidationError) as e:
ModelForwarderConfig(
model_name=ModelName.marker_quality_control,
model_type=ModelType.classification,
model_name=ModelName.MARKER_QUALITY_CONTROL,
model_type=ModelType.CLASSIFICATION,
class_names_filepath=unexisting_path,
expected_image_resolution=ImageResolution(width=224, height=224),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,17 @@ def test_camera_rule_manager(
),
},
predictions={
"camera_#1": ClassifPrediction(prediction_type=PredictionType.class_, label="OK", probability=0.41),
"camera_#2": ClassifPrediction(prediction_type=PredictionType.class_, label="KO", probability=0.96),
"camera_#1": ClassifPrediction(prediction_type=PredictionType.CLASS_, label="OK", probability=0.41),
"camera_#2": ClassifPrediction(prediction_type=PredictionType.CLASS_, label="KO", probability=0.96),
"camera_#3": DetectionPrediction(
prediction_type=PredictionType.objects,
prediction_type=PredictionType.OBJECTS,
detected_objects={
"object_#1": DetectedObject(location=[1, 2, 3, 4], objectness=0.6578, label="bike"),
"object_#2": DetectedObject(location=[1, 2, 3, 4], objectness=0.6578, label="moto"),
},
),
"camera_#4": DetectionPrediction(
prediction_type=PredictionType.objects,
prediction_type=PredictionType.OBJECTS,
detected_objects={
"object_#1": DetectedObject(location=[1, 2, 3, 4], objectness=0.6578, label="bike"),
"object_#2": DetectedObject(location=[1, 2, 3, 4], objectness=0.6578, label="moto"),
Expand Down
Loading
Loading