Skip to content
Merged
29 changes: 21 additions & 8 deletions app/api/endpoints/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,43 @@
import redis.asyncio as redis
import uuid
from datetime import datetime
import json

router = APIRouter()

def get_redis_client(request: Request) -> redis.Redis:
    """FastAPI dependency: return the shared Redis client stored on app state."""
    client = request.app.state.redis_client
    return client

"""
테스트를 위한 임시적인 API
"""
@router.post("/predict", status_code=202)
async def create_prediction_job(job_request: JobRequest, redis_client: redis.Redis = Depends(get_redis_client)):
    """Accept a prediction request and publish it to the Redis job stream.

    Returns 202 immediately with the generated ``job_id`` (a UUID4
    correlation ID); the actual image analysis runs asynchronously in the
    worker, which publishes its result to the reply stream.
    """
    correlationId = str(uuid.uuid4())
    job = ImageJob(
        correlationId=correlationId,
        presignedUrl=job_request.presignedUrl,
        replyQueue=settings.STREAM_RESULT,
        # No callback yet — the result is read from the reply stream instead.
        callbackUrl=None,
        # NOTE(review): JobRequest as visible here only declares presignedUrl;
        # confirm it also carries a contentType field before relying on this.
        contentType=job_request.contentType,
        createdAt=datetime.utcnow().isoformat(),
        ttlSec=3600,
    )

    # Serialize the job into the stream entry's payload field.
    payload = json.dumps(job.dict())
    print("분석 결과 발행 시작...")
    entry_id = await redis_client.xadd(
        settings.STREAM_JOB,
        {
            "type": "image_results",
            "payload": payload,
            "correlationId": correlationId,
        },
        maxlen=10_000,      # keep the stream bounded
        approximate=True,   # "~maxlen" trimming is cheaper on the server
    )
    print("분석 결과 발행 완료...")

    return {"job_id": correlationId}
31 changes: 18 additions & 13 deletions app/schemas/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,25 @@
from pydantic import BaseModel, Field

class JobRequest(BaseModel):
    # Client request body for POST /predict.
    presignedUrl: str = Field(..., description="다운로드할 이미지의 Presigned URL")

"""
FastAPI가 image.results Stream에 발행하는 메시지
"""
class JobResult(BaseModel):
    # Message FastAPI publishes to the image.results stream.
    # Field names are camelCase to match the wire contract with Spring.
    correlationId: str
    pillName: str
    isSafe: int
    description: str
    finishedAt: str

"""
Spring으로부터 받는 Job(image.jobs 구독)
"""
class ImageJob(BaseModel):
    # Job consumed from the image.jobs stream (published by Spring).
    # NOTE(review): alias == field name is redundant now that the fields are
    # camelCase themselves; kept to state the wire contract explicitly.
    correlationId: str = Field(alias="correlationId")
    presignedUrl: str = Field(alias="presignedUrl")
    replyQueue: str = Field(alias="replyQueue")
    contentType: str = Field(alias="contentType")
    createdAt: str = Field(alias="createdAt")
    ttlSec: int = Field(alias="ttlSec")
141 changes: 103 additions & 38 deletions app/worker/redis_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import redis
import json

from redis.asyncio import Redis as AsyncRedis
from pydantic import BaseModel, Field
from typing import Any, Dict, Optional, List
from typing import Any, Dict, List, Optional, Union
from app.core.config import settings

# Redis Stream Definition
Expand All @@ -12,7 +12,6 @@ class PublishRequest(BaseModel):
payload: Dict[str, Any]

class RedisStreamClient:

def __init__(self):
self.redis_client = AsyncRedis.from_url(
url=settings.REDIS_URL,
Expand All @@ -21,38 +20,74 @@ def __init__(self):

@classmethod
def init(cls):
broker = cls()
return broker
return cls()

@staticmethod
def _to_scalar(v: Any) -> Union[str, bytes, int, float]:
    """Coerce a value into a type Redis XADD accepts as a field value."""
    # Redis XADD accepts str/bytes/int/float field values as-is.
    if isinstance(v, (str, bytes, int, float)):
        return v
    # Everything else is serialized to a JSON string
    # (ensure_ascii=False preserves non-ASCII text such as Korean).
    return json.dumps(v, ensure_ascii=False)


@classmethod
def _sanitize_fields_for_xadd(cls, fields: Dict[str, Any]) -> Dict[str, Union[str, bytes, int, float]]:
    """Return a copy of *fields* safe for XADD: string keys, scalar values."""
    clean: Dict[str, Union[str, bytes, int, float]] = {}
    for k, v in fields.items():
        # Field names must be strings; coerce any non-str key.
        if not isinstance(k, str):
            k = str(k)
        clean[k] = cls._to_scalar(v)
    return clean


# Fast API 에서 Publish
async def xadd(self, stream_name: str, fields: Dict[str, Any]) -> str:
return await self.redis_client.xadd(stream_name, fields)
async def xadd(
    self,
    stream_name: str,
    fields: Dict[str, Any],
    *,
    maxlen: Optional[int] = 10_000,
    approximate: bool = True,
    nomkstream: bool = False,
    id: str = "*",
) -> str:
    """Append an entry to *stream_name*, coercing field values to XADD-safe
    scalars first.

    Parameters mirror redis-py's XADD: ``maxlen``/``approximate`` cap the
    stream length (~10k by default), ``nomkstream`` skips auto-creating the
    stream, and ``id`` defaults to ``*`` (server-assigned). Returns the
    new entry's ID.
    """
    # Non-scalar values are JSON-serialized so XADD never rejects them.
    safe_fields = self._sanitize_fields_for_xadd(fields)
    return await self.redis_client.xadd(
        stream_name,
        safe_fields,
        id=id,
        maxlen=maxlen,
        approximate=approximate,
        nomkstream=nomkstream,
    )

# Group 단위로 읽어오기
async def xreadgroup(
self,
group_name: str,
consumer_name: str,
stream_name: str,
count: Optional[int] = None,
block: Optional[int] = None, # ms 단위
id: str = ">", # 새 메시지만 읽기
self,
group_name: str,
consumer_name: str,
stream_name: str,
count: Optional[int] = None,
block: Optional[int] = None, # ms
) -> List[tuple]:
# Consumer Group 생성 (없으면)
try:
# Create the consumer group (존재하지 않을 때)
self.redis_client.xgroup_create(
stream_name, group_name, id="$", mkstream=True
await self.redis_client.xgroup_create(
name=stream_name,
groupname=group_name,
id="0",
mkstream=True,
)
except redis.exceptions.ResponseError as e:
# 이미 존재할 때
if "BUSYGROUP" not in str(e):
raise

streams = {stream_name: id}
streams = {stream_name: ">"} # 신규 메시지
response = await self.redis_client.xreadgroup(
group_name, # groupname (positional)
consumer_name, # consumername (positional)
streams, # {stream: id} (positional)
group_name,
consumer_name,
streams,
count=count,
block=block,
)
Expand All @@ -63,36 +98,38 @@ async def xack(self, stream_name: str, group_name: str, message_ids: List[str])
return await self.redis_client.xack(stream_name, group_name, *message_ids)

# 완료 시 삭제
async def xack_and_del(self, stream_name: str, group_name: str, message_ids: str) -> int:

acked_count = await self.redis_client.xack(stream_name, group_name, *message_ids)

# XACK가 성공하면 스트림에서 해당 메시지를 삭제 (XDEL)
async def xack_and_del(
    self,
    stream_name: str,
    group_name: str,
    message_ids: Union[str, List[str]],
) -> int:
    """Acknowledge the given message(s) for *group_name*, then delete them
    from the stream. Returns the number of messages XACK reported acked."""
    # Normalize to a list so a lone ID string isn't splatted char-by-char.
    ids = [message_ids] if isinstance(message_ids, str) else list(message_ids)
    acked_count = await self.redis_client.xack(stream_name, group_name, *ids)
    # Only XDEL when at least one message was actually acknowledged.
    if acked_count > 0:
        await self.redis_client.xdel(stream_name, *ids)
    return acked_count

# Create a consumer group.
async def xgroup_create(self, stream_name: str, group_name: str, id: str = "$") -> bool:
    """Create consumer group *group_name* on *stream_name*.

    Returns True when the group was created, False when it already existed
    (BUSYGROUP); any other ResponseError propagates to the caller.
    """
    try:
        # mkstream=True also creates the stream itself if it doesn't exist.
        await self.redis_client.xgroup_create(stream_name, group_name, id, mkstream=True)
        return True
    except redis.exceptions.ResponseError as e:
        # BUSYGROUP just means the group already exists — not a failure.
        if "BUSYGROUP" in str(e):
            print(f"Consumer group '{group_name}' already exists.")
            return False
        raise

# 메시지 재처리 지원
async def xclaim(
self,
stream_name: str,
group_name: str,
consumer_name: str,
min_idle_time: int,
message_ids: List[str],
self,
stream_name: str,
group_name: str,
consumer_name: str,
min_idle_time: int,
message_ids: List[str],
) -> List[tuple]:

return await self.redis_client.xclaim(
stream_name=stream_name,
group_name=group_name,
Expand All @@ -101,6 +138,34 @@ async def xclaim(
message_ids=message_ids,
)

# Automatically re-claim pending messages from idle consumers.
async def xautoclaim(
    self,
    name: str,
    groupname: str,
    consumername: str,
    min_idle_time: int,
    start_id: str = "0-0",
    count: Optional[int] = None,
    justid: bool = False,
):
    """Wrapper around XAUTOCLAIM that normalizes the reply shape.

    Some server/client combinations return a 3-tuple
    (next_id, messages, deleted_ids), others a 2-tuple. A 3-tuple is
    collapsed to (next_id, messages); anything else is returned as-is.
    """
    res = await self.redis_client.xautoclaim(
        name=name,
        groupname=groupname,
        consumername=consumername,
        min_idle_time=min_idle_time,
        start_id=start_id,
        count=count,
        justid=justid,
    )
    # Pre-process so callers can rely on a 2-tuple regardless of reply form.
    if isinstance(res, (list, tuple)) and len(res) == 3:
        next_id, messages, _deleted = res
        return next_id, messages
    return res


# Shutdown.
async def aclose(self):
    """Close the underlying Redis connection.

    NOTE(review): redis-py 5.x deprecates ``close()`` in favor of
    ``aclose()`` — confirm the pinned redis-py version before renaming
    this call.
    """
    await self.redis_client.close()

Expand Down
47 changes: 33 additions & 14 deletions app/worker/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,53 @@
from app.services.predictor_service import predictor_service
from app.services.s3_service import s3_service

"""
이미지를 다운 -> 다운 한 것에 대하여 모델 분석 요청
"""
async def process_image_scan(job: ImageJob, redis_client: redis.Redis):
    """Download the job's image, run the pill classifier, and publish a
    JobResult to the result stream.

    Errors are logged and swallowed deliberately: the worker loop must not
    die because of a single bad job.
    """
    correlationId = job.correlationId
    print(f"[task] Start image scan for job_id={correlationId}")
    try:
        # Pass the callable and its argument separately — calling it inline
        # would run the blocking download before to_thread ever started.
        stream_file = await asyncio.to_thread(
            s3_service.download_file_from_presigned_url,
            job.presignedUrl
        )

        # Rewind so the predictor reads from the start of the buffer.
        stream_file.seek(0)

        # label/confidence are currently unused: the result schema below
        # does not carry them yet.
        pillName, label, confidence = await asyncio.to_thread(
            predictor_service.predict,
            stream_file
        )

        # TODO: feed the prediction into ChatGPT and use its output here.
        isSafe = 0
        description = "일단은 테스트입니다. 추후에 GPT 부분 추가할 예정"
        finishedAt = datetime.utcnow().isoformat()

        result = JobResult(
            correlationId=correlationId,
            pillName=pillName,
            isSafe=isSafe,
            description=description,
            finishedAt=finishedAt,
        )

        # Publish the result; bounded stream, approximate trim for speed.
        await redis_client.xadd(
            settings.STREAM_RESULT,
            {
                "correlationId": correlationId,
                "type": "image_results",
                "payload": result.model_dump_json()},
            maxlen=10_000,
            approximate=True,
        )

        print(f"[task] Image scan successfully finished for job_id={correlationId}")

    except Exception as e:
        print(f"[task] Failed to process job_id={correlationId}: {e}")
    finally:
        print(f"[task] Image scan finished for job_id={correlationId}")
Loading