diff --git a/equivalence_model_service.py b/equivalence_model_service.py index ef3edf1..da8160f 100644 --- a/equivalence_model_service.py +++ b/equivalence_model_service.py @@ -79,7 +79,7 @@ def _initialize_model(self): raise def health_check(self) -> bool: - """Check if the model is properly loaded and can perform inference.""" + """Check if the model is initialized and can perform inference.""" if not self.initialized: logger.error("Model not initialized") return False @@ -102,16 +102,14 @@ def classify_texts(self, premise: str, hypothesis: str) -> Dict[str, Any]: """Classify the semantic relationship between two texts. Args: - premise: The first text - hypothesis: The second text to compare with the premise - + premise: The first text. + hypothesis: The second text to compare with the premise. + Returns: - A dictionary containing: - - entailment_score: Probability that hypothesis follows from premise - - contradiction_score: Probability that hypothesis contradicts premise - - neutral_score: Probability that hypothesis is neutral to premise - - predicted_label: The predicted relationship - - is_equivalent: Boolean indicating if texts are semantically equivalent + A dictionary containing the scores for entailment, contradiction, and + neutrality, + the predicted relationship label, and a boolean indicating semantic + equivalence. """ if not self.initialized: raise RuntimeError("Model not initialized. Cannot perform classification.") @@ -163,7 +161,13 @@ def classify_texts(self, premise: str, hypothesis: str) -> Dict[str, Any]: raise def shutdown(self): - """Clean up resources.""" + """Shut down the model service and clean up resources. + + This function logs the shutdown process of the model service. It checks if the + model is not None and, if the device type is 'cuda', it attempts to clear the + CUDA memory cache. Any errors encountered during the shutdown process are + logged for debugging purposes. + """ try: logger.info("Shutting down model service") if self.model is not None: diff --git a/main.py b/main.py index 78e37fa..4751346 100644 --- a/main.py +++ b/main.py @@ -83,6 +83,7 @@ # Register shutdown function to clean up the model service def shutdown_model_service(): + """Shut down the model services and log the process.""" logger.info("Shutting down model services") try: model_service.shutdown() @@ -101,6 +102,16 @@ def shutdown_model_service(): @asynccontextmanager async def lifespan(app: FastAPI): # Startup: verify model is properly loaded + """FastAPI lifespan context manager for application startup and shutdown. + + This function manages the lifecycle of a FastAPI application by performing + health checks on the model services during startup and logging relevant + information. It verifies the functionality of both the similarity and + equivalence model services, handling any errors gracefully to ensure the + application continues running with potentially limited functionality. + Additionally, it logs memory usage statistics if the `psutil` library is + available. + """ logger.info("Application startup - Verifying model services are properly loaded...") try: # Check similarity model service health @@ -205,6 +216,7 @@ async def lifespan(app: FastAPI): # Request logging middleware @app.middleware("http") async def log_requests(request: Request, call_next): + """Log incoming HTTP requests and their processing time.""" request_id = f"req-{int(time.time() * 1000)}" logger.info(f"Request started [ID: {request_id}] - {request.method} {request.url.path}") @@ -223,7 +235,14 @@ async def log_requests(request: Request, call_next): # Add a health check endpoint @app.get("/health") async def health_check(): - """Health check endpoint with detailed metrics.""" + """Health check endpoint that provides detailed metrics. + + This endpoint performs a health check on various services, including the + similarity and equivalence model services, and gathers system metrics such as + memory and CPU usage. It logs the status of each service and the overall health + status, returning a structured response that includes timestamps, version + information, and any warnings related to system performance. + """ logger.info("Health check endpoint called") health_status = { @@ -336,6 +355,7 @@ class CompareRequest(BaseModel): # Route to compare sentences @app.post("/compare") async def compare_sentences(data: CompareRequest): + """Compare two sentences and calculate their semantic similarity.""" request_id = f"req-{int(time.time() * 1000)}" logger.info(f"Compare endpoint called [ID: {request_id}]") logger.info(f"Sentence 1 ({len(data.sentence1)} chars): {data.sentence1[:50]}...") @@ -369,6 +389,7 @@ class EquivalenceRequest(BaseModel): # Route to check semantic equivalence @app.post("/classify-equivalence") async def classify_equivalence(data: EquivalenceRequest): + """Classify the semantic equivalence of a premise and hypothesis.""" request_id = f"req-{int(time.time() * 1000)}" logger.info(f"Equivalence endpoint called [ID: {request_id}]") logger.info(f"Premise ({len(data.premise)} chars): {data.premise[:50]}...") diff --git a/model_service.py b/model_service.py index 7eba181..e3e7545 100644 --- a/model_service.py +++ b/model_service.py @@ -80,7 +80,7 @@ def __init__(self, model_name, cache_folder, device=None): raise def encode(self, text): - """Encode text using the model.""" + """Encode text using the model and log the encoding time.""" try: start_time = time.time() embedding = self.model.encode(text, convert_to_tensor=True) @@ -120,7 +120,14 @@ def calculate_similarity(self, text1, text2): raise RuntimeError(f"Error calculating similarity: {e}") def health_check(self): - """Check if the model is healthy by running a simple test.""" + """Check if the model is healthy by running a simple test. + + This function performs a series of checks to ensure that the model is properly + loaded and functional. It verifies the existence of the model, checks for the + presence of the 'encode' method, and attempts a simple encoding operation. If + any of these checks fail, an error is logged, and the function returns False. + Otherwise, it logs a success message and returns True. + """ try: logger.info("Running health check") @@ -147,7 +154,7 @@ def health_check(self): return False def shutdown(self): - """Clean up resources.""" + """Clean up resources and shut down the model service.""" logger.info("Shutting down model service") # Clear model reference to help with garbage collection