equivalence_model_service.py (26 changes: 15 additions & 11 deletions)
@@ -79,7 +79,7 @@ def _initialize_model(self):
             raise

     def health_check(self) -> bool:
-        """Check if the model is properly loaded and can perform inference."""
+        """Check if the model is initialized and can perform inference."""
         if not self.initialized:
             logger.error("Model not initialized")
             return False
@@ -102,16 +102,14 @@ def classify_texts(self, premise: str, hypothesis: str) -> Dict[str, Any]:
         """Classify the semantic relationship between two texts.

         Args:
-            premise: The first text
-            hypothesis: The second text to compare with the premise
+            premise: The first text.
+            hypothesis: The second text to compare with the premise.

         Returns:
-            A dictionary containing:
-            - entailment_score: Probability that hypothesis follows from premise
-            - contradiction_score: Probability that hypothesis contradicts premise
-            - neutral_score: Probability that hypothesis is neutral to premise
-            - predicted_label: The predicted relationship
-            - is_equivalent: Boolean indicating if texts are semantically equivalent
+            A dictionary containing the entailment, contradiction, and
+            neutrality scores, the predicted relationship label, and a
+            boolean indicating whether the texts are semantically
+            equivalent.
         """
         if not self.initialized:
             raise RuntimeError("Model not initialized. Cannot perform classification.")
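The replacement docstring drops the per-key breakdown, but the hunk leaves the returned keys themselves unchanged. For illustration, a minimal sketch of consuming the documented dictionary (the service instance and its construction are assumptions, not shown in this diff):

    # Hypothetical usage; `service` is an assumed instance of this class.
    result = service.classify_texts(
        premise="A man is playing a guitar.",
        hypothesis="A person is playing an instrument.",
    )
    # Documented keys: entailment_score, contradiction_score, neutral_score,
    # predicted_label, is_equivalent.
    if result["is_equivalent"]:
        print(f"{result['predicted_label']} (p={result['entailment_score']:.2f})")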
@@ -163,7 +161,13 @@ def classify_texts(self, premise: str, hypothesis: str) -> Dict[str, Any]:
             raise

     def shutdown(self):
-        """Clean up resources."""
+        """Shut down the model service and clean up resources.
+
+        This method logs the shutdown process. It checks that the model is not
+        None and, if the device type is 'cuda', attempts to clear the CUDA
+        memory cache. Any errors encountered during shutdown are logged for
+        debugging purposes.
+        """
         try:
             logger.info("Shutting down model service")
             if self.model is not None:
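The expanded docstring describes the CUDA cleanup that follows the `if self.model is not None:` context line above. A minimal sketch of that pattern, assuming a torch device attribute on the service (the real method may differ in detail):

    import logging
    import torch

    logger = logging.getLogger(__name__)

    def shutdown(self):
        try:
            logger.info("Shutting down model service")
            if self.model is not None and self.device.type == "cuda":
                torch.cuda.empty_cache()  # release cached GPU memory
        except Exception as e:
            logger.error(f"Error during shutdown: {e}")  # logged for debugging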
main.py (23 changes: 22 additions & 1 deletion)
@@ -83,6 +83,7 @@
 
 # Register shutdown function to clean up the model service
 def shutdown_model_service():
+    """Shut down the model services and log the process."""
     logger.info("Shutting down model services")
     try:
         model_service.shutdown()
@@ -101,6 +102,16 @@ def shutdown_model_service():
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Startup: verify model is properly loaded
+    """FastAPI lifespan context manager for application startup and shutdown.
+
+    This function manages the lifecycle of a FastAPI application by performing
+    health checks on the model services during startup and logging relevant
+    information. It verifies that both the similarity and equivalence model
+    services are functional, handling any errors gracefully so that the
+    application can continue running with potentially limited functionality.
+    Additionally, it logs memory usage statistics if the `psutil` library is
+    available.
+    """
     logger.info("Application startup - Verifying model services are properly loaded...")
     try:
         # Check similarity model service health
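A hedged sketch of the startup flow the docstring describes; the service instance names and their health_check() calls are assumptions based on the surrounding diff:

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        try:
            # Assumed service names; each health_check() returns a bool.
            if not model_service.health_check():
                logger.warning("Similarity model unhealthy; limited functionality")
            if not equivalence_model_service.health_check():
                logger.warning("Equivalence model unhealthy; limited functionality")
        except Exception as e:
            logger.error(f"Startup health check error: {e}")
        yield  # the application serves requests here
        shutdown_model_service()  # cleanup on shutdown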
@@ -205,6 +216,7 @@ async def lifespan(app: FastAPI):
 # Request logging middleware
 @app.middleware("http")
 async def log_requests(request: Request, call_next):
+    """Log incoming HTTP requests and their processing time."""
     request_id = f"req-{int(time.time() * 1000)}"
     logger.info(f"Request started [ID: {request_id}] - {request.method} {request.url.path}")
 
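Only the request side of the logging is visible in this hunk. A sketch of how the processing time mentioned in the docstring is typically captured on the response side (the response log line is an assumption, not shown in the diff):

    @app.middleware("http")
    async def log_requests(request: Request, call_next):
        """Log incoming HTTP requests and their processing time."""
        request_id = f"req-{int(time.time() * 1000)}"
        start = time.time()
        response = await call_next(request)
        logger.info(f"Request finished [ID: {request_id}] in {time.time() - start:.3f}s")
        return response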
@@ -223,7 +235,14 @@ async def log_requests(request: Request, call_next):
 # Add a health check endpoint
 @app.get("/health")
 async def health_check():
-    """Health check endpoint with detailed metrics."""
+    """Health check endpoint that provides detailed metrics.
+
+    This endpoint performs a health check on the similarity and equivalence
+    model services and gathers system metrics such as memory and CPU usage. It
+    logs the status of each service and the overall health status, returning a
+    structured response that includes timestamps, version information, and any
+    warnings related to system performance.
+    """
     logger.info("Health check endpoint called")
 
     health_status = {
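The docstring mentions memory and CPU metrics. A sketch of gathering them with psutil, which the lifespan docstring already names as an optional dependency (the exact response fields and threshold are assumptions):

    import psutil

    memory = psutil.virtual_memory()
    health_status["system"] = {
        "memory_percent": memory.percent,
        "cpu_percent": psutil.cpu_percent(interval=0.1),
    }
    if memory.percent > 90:  # assumed warning threshold
        health_status.setdefault("warnings", []).append("High memory usage")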
@@ -336,6 +355,7 @@ class CompareRequest(BaseModel):
 # Route to compare sentences
 @app.post("/compare")
 async def compare_sentences(data: CompareRequest):
+    """Compare two sentences and calculate their semantic similarity."""
     request_id = f"req-{int(time.time() * 1000)}"
     logger.info(f"Compare endpoint called [ID: {request_id}]")
     logger.info(f"Sentence 1 ({len(data.sentence1)} chars): {data.sentence1[:50]}...")
@@ -369,6 +389,7 @@ class EquivalenceRequest(BaseModel):
 # Route to check semantic equivalence
 @app.post("/classify-equivalence")
 async def classify_equivalence(data: EquivalenceRequest):
+    """Classify the semantic equivalence of a premise and hypothesis."""
     request_id = f"req-{int(time.time() * 1000)}"
     logger.info(f"Equivalence endpoint called [ID: {request_id}]")
     logger.info(f"Premise ({len(data.premise)} chars): {data.premise[:50]}...")
model_service.py (13 changes: 10 additions & 3 deletions)
@@ -80,7 +80,7 @@ def __init__(self, model_name, cache_folder, device=None):
             raise

     def encode(self, text):
-        """Encode text using the model."""
+        """Encode text using the model and log the encoding time."""
         try:
             start_time = time.time()
             embedding = self.model.encode(text, convert_to_tensor=True)
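The encode call above uses the sentence-transformers API. A hedged sketch of turning two encodings into a similarity score, roughly what calculate_similarity in the next hunk would do (the service instance is assumed):

    from sentence_transformers import util

    emb1 = service.encode("The weather is nice today.")
    emb2 = service.encode("It is a pleasant day outside.")
    score = util.cos_sim(emb1, emb2).item()  # cosine similarity in [-1, 1]
    print(f"similarity: {score:.3f}")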
@@ -120,7 +120,7 @@ def calculate_similarity(self, text1, text2):
             raise RuntimeError(f"Error calculating similarity: {e}")

     def health_check(self):
-        """Check if the model is healthy by running a simple test."""
+        """Check if the model is healthy by running a simple test.
+
+        This method performs a series of checks to ensure that the model is
+        properly loaded and functional. It verifies that the model exists, checks
+        for the presence of the 'encode' method, and attempts a simple encoding
+        operation. If any of these checks fail, an error is logged and the method
+        returns False. Otherwise, it logs a success message and returns True.
+        """
         try:
             logger.info("Running health check")
 
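The expanded docstring enumerates three checks. A minimal sketch of that sequence; the attribute names and messages are assumptions, only the structure follows the docstring:

    def health_check(self):
        try:
            logger.info("Running health check")
            if self.model is None:
                logger.error("Health check failed: model is not loaded")
                return False
            if not hasattr(self.model, "encode"):
                logger.error("Health check failed: model has no encode method")
                return False
            self.model.encode("health check")  # simple test encoding
            logger.info("Health check passed")
            return True
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return False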
@@ -147,7 +154,7 @@ def health_check(self):
             return False

     def shutdown(self):
-        """Clean up resources."""
+        """Clean up resources and shut down the model service."""
         logger.info("Shutting down model service")

         # Clear model reference to help with garbage collection
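The context line above notes that the model reference is cleared to help with garbage collection. A sketch of that idiom (whether the service also forces a collection is an assumption):

    import gc

    def shutdown(self):
        """Clean up resources and shut down the model service."""
        logger.info("Shutting down model service")
        self.model = None  # drop the reference so the model can be collected
        gc.collect()       # optionally prompt an immediate collection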