Add onnx serving #123

base: main
**Dockerfile**

```dockerfile
FROM python:3.8-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get upgrade -y && apt-get install -y gcc

COPY src /app/src
COPY pyproject.toml /app/
COPY models/onnx /models/onnx

# Install Python dependencies
RUN pip install --upgrade pip
RUN pip install -e "/app[linux]"

EXPOSE 8000

ENV PYTHONPATH=/app/src
ENTRYPOINT ["uvicorn"]
CMD ["onnx_server:app", "--reload", "--host", "0.0.0.0", "--port", "8000"]
```
**pyproject.toml**

```toml
[build-system]
requires = ["setuptools>=66", "setuptools_scm[toml]>=6.2"]
build-backend = "setuptools.build_meta"

[project]
name = "onnx_serving"
version = "1.0.0"
dependencies = [
    "fastapi==0.92.0",
    "numpy==1.24.2",
    "Pillow==9.4.0",
    "uvicorn==0.20.0",
    "onnxruntime==1.16.0",
    "opencv-python-headless==4.7.0.72",
    "anyio==3.7.1"
]
requires-python = ">=3.8"

[tool.setuptools.packages.find]
where = ["src/"]

[project.optional-dependencies]
dev = [
    "black==23.3.0",
    "isort==5.13.2",
    "flake8==7.1.1",
    "autoflake==2.3.1",
    "pytest==7.2.2",
    "pytest-cov==4.0.0",
    "requests==2.26.0"
]

[tool.pytest.ini_options]
minversion = "6.0"
testpaths = "tests/"

[tool.flake8]
exclude = "venv*"
max-complexity = 10
max-line-length = 120

[tool.isort]
profile = "black"
```
**src/api_routes.py**

```python
import logging
import os
import time
from pathlib import Path
from typing import Any, AnyStr, Dict, List, Tuple, Union

import numpy as np
from fastapi import APIRouter, HTTPException, Request
from PIL import Image

from utils.yolo11n_postprocessing import (
    compute_severities,
    non_max_suppression,
    yolo_extract_boxes_information,
)

JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]

api_router = APIRouter()


# Landing page
@api_router.get("/")
async def info():
    return """Welcome to the onnx-server for VIO !!"""


# List the loaded ONNX models
@api_router.get("/models")
async def get_models(request: Request) -> List[str]:
    return list(
        request.app.state.model_interpreters.keys()
    )  # Return the available models


# Get a model's metadata
@api_router.get("/models/{model_name}/versions/{model_version}/resolution")
async def get_model_metadata(
    model_name: str, model_version: str, request: Request
) -> Dict[str, Tuple]:
    session = request.app.state.model_interpreters[model_name]
    input_details = session.get_inputs()
    return {"inputs_shape": input_details[0].shape}
```
> **Collaborator:** Is this the only metadata we have on the model?
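For reference, an onnxruntime `InferenceSession` exposes quite a bit more than the input shape. A sketch of what a fuller metadata payload could look like (the function name and response fields below are illustrative, not part of this PR):

```python
from typing import Any, Dict

from fastapi import Request


async def get_model_metadata_full(model_name: str, request: Request) -> Dict[str, Any]:
    """Illustrative only: richer metadata available from an onnxruntime session."""
    session = request.app.state.model_interpreters[model_name]
    meta = session.get_modelmeta()  # onnxruntime.ModelMetadata
    return {
        "inputs": [
            {"name": i.name, "shape": i.shape, "type": i.type}
            for i in session.get_inputs()
        ],
        "outputs": [
            {"name": o.name, "shape": o.shape, "type": o.type}
            for o in session.get_outputs()
        ],
        "producer_name": meta.producer_name,
        "graph_name": meta.graph_name,
        "custom_metadata": meta.custom_metadata_map,
    }
```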
```python
# Run a prediction


def load_test_image(path: Path) -> np.ndarray:
    if not os.path.exists(path):
        raise FileNotFoundError(f"Test image not found: {path}")
    img = Image.open(path).convert("RGB")
    img = img.resize((640, 640))
    arr = np.array(img, dtype=np.float32) / 255.0  # Normalization: [640, 640, 3]
    arr = arr.transpose(2, 0, 1)  # --> [3, 640, 640]
    arr = np.expand_dims(arr, axis=0)  # --> [1, 3, 640, 640]: the layout YOLO expects
    return arr


# TODO: change the endpoint so the user can send an image with the request
```
> **Collaborator:** Can you do that, please?
```python
@api_router.post("/models/{model_name}/versions/{model_version}:predict")
async def predict_test_image(model_name: str, model_version: str, request: Request):
```
> **Collaborator:** Please add the return type hint.
```python
    HERE = Path(__file__).resolve().parent

    # Check that the model exists
    if model_name not in request.app.state.model_interpreters:
        raise HTTPException(status_code=404, detail="Model not found")
```
> **Collaborator:** Suggested change.
```python
    session = request.app.state.model_interpreters[
        model_name
    ]  # Get the model's inference session

    # Load the image
    try:
        test_img_path = HERE / "data" / "test_img.jpg"
        input_array = load_test_image(test_img_path)
    except FileNotFoundError as e:
        raise HTTPException(status_code=500, detail=str(e))
```
> **Collaborator** (on lines +74 to +79): To be removed, so that the image passed as a request parameter is used instead.
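A minimal sketch of what that could look like with FastAPI's `UploadFile`, reusing the same preprocessing (the route name and handler below are illustrative, not part of this PR; multipart uploads also require the `python-multipart` package):

```python
import io

from fastapi import File, UploadFile


@api_router.post("/models/{model_name}/versions/{model_version}:predict-upload")
async def predict_uploaded_image(
    model_name: str,
    model_version: str,
    request: Request,
    image: UploadFile = File(...),
):
    """Illustrative only: accept the image as part of the request."""
    if model_name not in request.app.state.model_interpreters:
        raise HTTPException(status_code=404, detail="Model not found")
    session = request.app.state.model_interpreters[model_name]

    # Same preprocessing as load_test_image, applied to the uploaded bytes
    raw = await image.read()
    img = Image.open(io.BytesIO(raw)).convert("RGB").resize((640, 640))
    arr = np.array(img, dtype=np.float32) / 255.0
    input_array = np.expand_dims(arr.transpose(2, 0, 1), axis=0)  # [1, 3, 640, 640]

    ort_inputs = {session.get_inputs()[0].name: input_array}
    outputs = session.run(None, ort_inputs)
    return {"outputs_shape": [list(np.asarray(o).shape) for o in outputs]}
```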
```python
    logging.info(f"Loaded test image, final shape {input_array.shape}")
```
> **Collaborator:** In English, please.
```python
    # Inference
    try:
        input_details = session.get_inputs()
        ort_inputs = {input_details[0].name: input_array}
        outputs = session.run(None, ort_inputs)
    except Exception:
        raise HTTPException(status_code=500, detail="ONNX inference error")

    # Post-processing
    try:
        outputs = outputs[0][0]
        boxes, scores, class_ids = yolo_extract_boxes_information(outputs)
        boxes, scores, class_ids = non_max_suppression(boxes, scores, class_ids)
        severities = compute_severities(input_array[0], boxes)

        prediction = {
            "outputs": {
                "detection_boxes": [boxes.tolist()],
                "detection_classes": [class_ids.tolist()],
                "detection_scores": [scores.tolist()],
                "severities": [severities],
            }
        }
        return prediction
    except Exception:
        raise HTTPException(status_code=500, detail="Post-processing error")


@api_router.post("/models/{model_name}/performance")
async def model_performance(model_name: str, request: Request):

    # Check that the model exists
    if model_name not in request.app.state.model_interpreters:
        raise HTTPException(status_code=404, detail="Model not found")
    session = request.app.state.model_interpreters[model_name]
    input_details = session.get_inputs()

    # Load the test image
    HERE = Path(__file__).resolve().parent
    test_img_path = HERE / "data" / "test_img.jpg"
    try:
        input_array = load_test_image(test_img_path)
    except FileNotFoundError as e:
        raise HTTPException(status_code=500, detail=str(e))

    # Check the input layout YOLO expects
    if input_array.shape != (1, 3, 640, 640):
        raise HTTPException(
            status_code=400,
            detail=f"Input dimensions must be [1,3,640,640], got {input_array.shape}",
        )

    # Inference, with timing
    try:
        ort_inputs = {input_details[0].name: input_array}
        start = time.time()
        _ = session.run(None, ort_inputs)
        exec_time = time.time() - start
    except Exception:
        raise HTTPException(status_code=500, detail="ONNX inference error")

    return {
        "model_name": model_name,
        "input_shape": input_array.shape,
        "inference_time_sec": round(exec_time, 4),
    }
```
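For manual testing once the container is up, something like the following exercises the routes (it assumes the server is reachable on localhost:8000 and at least one model is loaded; `requests` is already in the dev dependencies):

```python
import requests

BASE = "http://localhost:8000"

# List the loaded models
models = requests.get(f"{BASE}/models").json()
print(models)

# Predict on the bundled test image, then time a run
name = models[0]
pred = requests.post(f"{BASE}/models/{name}/versions/1:predict").json()
print(pred["outputs"]["detection_scores"])

perf = requests.post(f"{BASE}/models/{name}/performance").json()
print(perf["inference_time_sec"])
```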
**src/onnx_interpreter.py**

```python
import os
from pathlib import Path

import onnxruntime as ort

MODELS_PATH = Path(os.getenv("MODELS_PATH", "/models"))

# Find every .onnx model
model_files = list(MODELS_PATH.rglob("*.onnx"))

if not model_files:
    raise FileNotFoundError(f"No ONNX model found in {MODELS_PATH}")

print(f"Models found: {[str(m) for m in model_files]}")

# Default to the first available ONNX model
MODEL_PATH = str(model_files[0])
```
> **Collaborator** (on lines +5 to +16): Is this debug code? If not, it belongs in the class.
```python
class ONNXModel:
```
> **Collaborator:** The class name and the file name are inconsistent; please rename the file.
```python
    def __init__(self):
        """Initialize ONNX Runtime with every model found."""
```
> **Collaborator:** Please write the docstrings in English.
```python
        self.models = {}
        model_files = list(MODELS_PATH.rglob("*.onnx"))
```
> **Collaborator:** Rather than a global variable, you should pass it in.

> **Collaborator:** Same for the magic string.
```python
        if not model_files:
            raise FileNotFoundError(f"No ONNX model found in {MODELS_PATH}")

        for model_path in model_files:
            model_name = model_path.stem  # Model name without the extension
            # One inference session per model
            self.models[model_name] = ort.InferenceSession(
                str(model_path), providers=["CPUExecutionProvider"]
            )
```
| print(f"Model {model_name} loaded from {model_path}") | ||
|
> **Collaborator:** Use the logger rather than `print`.
```python
"""
def predict(self, model_name, input_array):
    if model_name not in self.models:
        raise ValueError(f"Model {model_name} not loaded")

    session = self.models[model_name]
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_array})
    return outputs
"""
```
> **Collaborator** (on lines +36 to +44): Is this dead code? If so, please remove it.
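A sketch of the shape those review comments point toward: the models directory and the providers become constructor parameters, and the logger replaces `print` (the class name is illustrative, not part of this PR):

```python
import logging
from pathlib import Path
from typing import Dict, List, Optional

import onnxruntime as ort

logger = logging.getLogger(__name__)


class ONNXModelRegistry:
    """Illustrative only: same loading behavior, without module-level globals."""

    def __init__(self, models_path: Path, providers: Optional[List[str]] = None) -> None:
        providers = providers or ["CPUExecutionProvider"]
        self.models: Dict[str, ort.InferenceSession] = {}
        model_files = list(models_path.rglob("*.onnx"))
        if not model_files:
            raise FileNotFoundError(f"No ONNX model found in {models_path}")
        for model_path in model_files:
            self.models[model_path.stem] = ort.InferenceSession(
                str(model_path), providers=providers
            )
            logger.info("Model %s loaded from %s", model_path.stem, model_path)
```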
**src/onnx_server.py**

```python
import logging

from fastapi import FastAPI

from api_routes import api_router
from onnx_interpreter import ONNXModel

app = FastAPI()
```
> **Collaborator:** We need to adapt the …
```python
# Load the ONNX models at server startup
@app.on_event("startup")
async def load_model():
    logging.info("Loading ONNX models...")
    model = ONNXModel()  # Loads every model
    app.state.model_interpreters = model.models  # Store every model
    logging.info(f"Models loaded: {list(model.models.keys())}")
```
> **Collaborator** (on lines +9 to +14): Rather than making a function that initializes the `ONNXModel` class and then creating the FastAPI server, because the … Look here:
```python
app.include_router(api_router)


# Health check for the ONNX server
@app.get("/")
async def root():
    return {
        "message": "ONNX Model Serving is running!",
        "models": list(app.state.model_interpreters.keys()),
    }
```
> **Collaborator** (on lines +19 to +21): To be moved to the … side.
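On the startup comment above: FastAPI's lifespan handler is the usual replacement for the deprecated `on_event` hooks. A hedged sketch of that wiring (note the `lifespan` parameter needs FastAPI >= 0.93, newer than the 0.92.0 pinned in this PR):

```python
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI

from api_routes import api_router
from onnx_interpreter import ONNXModel

logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load every ONNX model once, before requests are served
    logger.info("Loading ONNX models...")
    model = ONNXModel()
    app.state.model_interpreters = model.models
    logger.info("Models loaded: %s", list(model.models.keys()))
    yield
    # Shutdown: nothing to release for CPU inference sessions


app = FastAPI(lifespan=lifespan)
app.include_router(api_router)
```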
**src/utils/yolo11n_postprocessing.py**

```python
import cv2
import numpy as np


# Compute the IoU score between two boxes
def compute_iou(box1, box2):
```
> **Collaborator:** Please add the parameter and return types.
```python
    x1, y1, x2, y2 = box1
    x1g, y1g, x2g, y2g = box2
```
> **Collaborator:** Why …?
```python
    xi1, yi1 = max(x1, x1g), max(y1, y1g)
    xi2, yi2 = min(x2, x2g), min(y2, y2g)
```
> **Collaborator** (on lines +9 to +10): Here the …
```python
    inter_width = max(0, xi2 - xi1)
    inter_height = max(0, yi2 - yi1)
    inter_area = inter_width * inter_height

    box1_area = (x2 - x1) * (y2 - y1)
    box2_area = (x2g - x1g) * (y2g - y1g)

    union_area = box1_area + box2_area - inter_area

    return inter_area / union_area if union_area > 0 else 0
```
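A quick hand-worked check of the formula: two 2x2 boxes offset by one unit along x intersect in a 1x2 strip, so IoU = 2 / (4 + 4 - 2) = 1/3.

```python
# Sanity check: intersection area 2, union area 6
assert abs(compute_iou([0, 0, 2, 2], [1, 0, 3, 2]) - 1 / 3) < 1e-9
```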
```python
# Extract the boxes, scores, and classes from the YOLO ONNX outputs
def yolo_extract_boxes_information(outputs, confidence_threshold=0.5):
```
> **Collaborator:** The input/output types, please.
| print(f"📌 Debug - outputs.shape: {np.array(outputs).shape}") # ✅ Affiche la forme des outputs | ||
|
> **Collaborator:** You can use …
```python
    boxes = []
    scores = []
    class_ids = []

    for output in outputs:
        print(f"📌 Debug - output: {output}")  # ✅ Check each row's structure

        # Make sure each output row has 6 values (x1, y1, x2, y2, conf, class_id)
        if len(output) < 6:
            print(f"⚠️ Warning: a YOLO output row only contains {len(output)} values: {output}")
            continue  # Skip this row

        x1, y1, x2, y2, conf, class_id = output[:6]  # ✅ Keep only the first 6 values
```
> **Collaborator:** What is in the other values, from 7 to ??
```python
        if conf > confidence_threshold:
```
> **Collaborator:** We drop a detection when its confidence is too low. I find that a shame: if we ever want to monitor the model, we lose the information. I would rather flag the prediction as ignored than reject it outright.
```python
            boxes.append([int(x1), int(y1), int(x2), int(y2)])
            scores.append(float(conf))
            class_ids.append(int(class_id))
```
> **Collaborator** (on lines +43 to +45): Why cast the variables? Is the type coming out of YOLO a string?
```python
    return np.array(boxes), np.array(scores), np.array(class_ids)


# Apply non-maximum suppression (NMS) to the detected boxes
def non_max_suppression(boxes, scores, class_ids, iou_threshold=0.3):
    # cv2.dnn.NMSBoxes expects (x, y, width, height) boxes, so convert from
    # the (x1, y1, x2, y2) corners built above
    boxes_xywh = [[x1, y1, x2 - x1, y2 - y1] for x1, y1, x2, y2 in boxes.tolist()]
    indices = cv2.dnn.NMSBoxes(boxes_xywh, scores.tolist(), 0.5, iou_threshold)
    if len(indices) == 0:
        return np.array([]), np.array([]), np.array([])

    filtered_boxes = []
    filtered_scores = []
    filtered_class_ids = []

    for i in np.asarray(indices).flatten():
        filtered_boxes.append(boxes[i])
        filtered_scores.append(scores[i])
        filtered_class_ids.append(class_ids[i])

    return np.array(filtered_boxes), np.array(filtered_scores), np.array(filtered_class_ids)
```
> **Collaborator** (on lines +50 to +62): Could you add a link to the documentation for those who don't know what NMS is?
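For readers who haven't met it: NMS greedily keeps the highest-scoring box and discards any remaining box whose IoU with it exceeds the threshold, so near-duplicate detections collapse to one. A minimal reference version using the `compute_iou` helper above (illustrative, not the code under review):

```python
def nms_reference(boxes, scores, iou_threshold=0.3):
    """Greedy NMS: keep the best-scoring boxes, drop overlapping ones."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        order = [i for i in order if compute_iou(boxes[best], boxes[i]) <= iou_threshold]
    return keep  # indices of the boxes to keep
```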
```python
# 📌 Compute the "severities" (placeholder, to be adapted as needed)
def compute_severities(frame, boxes):
    severities = []
    for box in boxes:
        severity = np.random.uniform(0, 1)  # Example: generate a random score
        severities.append(severity)
    return severities
```
> **Collaborator** (on lines +65 to +70): I don't quite see the point of generating random values. What's the idea?
> **Collaborator:** To be used in the prediction route, to pass the image as a parameter.