forked from JHZ5583233/Applied-ML-16
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFastAPI.py
More file actions
89 lines (70 loc) · 2.76 KB
/
FastAPI.py
File metadata and controls
89 lines (70 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import io
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
from project_name.models.cnn import CNNBackbone
from project_name.models.Preprocessing_class import Preprocessing
from numpy import array
# FastAPI application serving the depth-prediction endpoints defined below.
app = FastAPI(title="Depth Prediction API",
description="Uploads an image and "
"returns a predicted depth map.")
# Setup: load model weights and build the preprocessor once, at import time,
# so every request reuses the same in-memory model.
# Path to the trained CNN weights; relative to the process working directory.
MODEL_PATH = "cnn_best.pth"
model = CNNBackbone(pretrained=False)
# map_location="cpu": weights are loaded for CPU inference regardless of
# the device they were trained/saved on.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
# Switch to inference behavior (disables dropout/batch-norm updates).
model.eval()
# Preprocessor that tiles input images into 256x256 patches.
preprocessor = Preprocessing(tile_size=(256, 256))
def process_image(file_bytes: bytes,
                  model: torch.nn.Module,
                  preprocessor: Preprocessing) -> io.BytesIO:
    """Run the depth model over an uploaded image and render it as a PNG.

    The image is tiled (with padding) into fixed-size patches, each tile is
    passed through the model, the per-tile depth outputs are stitched back
    into a full depth map, colorized, and encoded as PNG bytes.

    Args:
        file_bytes (bytes): raw contents of the uploaded image file.
        model (torch.nn.Module): depth-prediction network, already in
            eval mode.
        preprocessor (Preprocessing): tiling / tensor-conversion helper.

    Returns:
        io.BytesIO: in-memory PNG of the RGB-colorized depth map, rewound
        to position 0 and ready for streaming.
    """
    img_array = preprocessor.load_image(io.BytesIO(file_bytes))
    tiles = preprocessor.tile_with_padding(img_array)

    # Enter no_grad() once for the whole batch of tiles instead of
    # re-entering the context on every iteration.
    with torch.no_grad():
        depth_tiles = [
            model(preprocessor.to_tensor(tile).unsqueeze(0))
            .squeeze()
            .cpu()
            .numpy()
            for tile in tiles
        ]

    depth_map = preprocessor.reconstruct_depth(array(depth_tiles))
    # invert=True flips the depth-to-color mapping.
    # NOTE(review): assumed this renders near objects bright — confirm
    # against Preprocessing.depth_to_rgb.
    depth_rgb = preprocessor.depth_to_rgb(depth_map, invert=True)

    result_image = Image.fromarray(depth_rgb)
    byte_io = io.BytesIO()
    result_image.save(byte_io, format="PNG")
    byte_io.seek(0)  # rewind so the caller streams from the start
    return byte_io
@app.post("/predict_depth/", summary="Predict depth from image")
async def predict_depth(file: UploadFile = File(...)) -> StreamingResponse:
    """Generate a depth-map PNG from an uploaded image.

    Args:
        file (UploadFile, optional): image file that is uploaded.
            Defaults to File(...).

    Raises:
        HTTPException: 400 when the upload is not an image.
        HTTPException: 500 when there is an error processing the image.

    Returns:
        StreamingResponse: response streaming the processed PNG image.
    """
    # content_type can be None on malformed uploads; guard it so we return
    # a clean 400 instead of crashing on .startswith with a 500.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid image format")
    try:
        contents = await file.read()
        image_bytes = process_image(contents, model, preprocessor)
        return StreamingResponse(image_bytes, media_type="image/png")
    except Exception as exc:
        # Chain the original exception so server logs keep the root cause.
        raise HTTPException(status_code=500,
                            detail="Error processing image.") from exc
@app.get("/", summary="Health check")
def read_root() -> dict:
    """Report service liveness.

    Returns:
        dict: payload with a single ``status`` key set to ``"healthy"``.
    """
    health_payload = {"status": "healthy"}
    return health_payload