11import traceback
2- from http import HTTPStatus # Add this import
2+ from http import HTTPStatus
33from typing import Callable , TypeVar
44
5- from starlette .background import BackgroundTasks
65from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
76from starlette .requests import Request
87from starlette .responses import JSONResponse , Response
98
109from codegen .runner .sandbox .runner import SandboxRunner
1110from codegen .shared .exceptions .compilation import UserCodeException
1211from codegen .shared .logging .get_logger import get_logger
13- from codegen .shared .performance .stopwatch_utils import stopwatch
1412
1513logger = get_logger (__name__ )
1614
@@ -34,13 +32,10 @@ async def dispatch(self, request: TRequest, call_next: RequestResponseEndpoint)
3432 return await call_next (request )
3533
3634 async def process_request (self , request : TRequest , call_next : RequestResponseEndpoint ) -> TResponse :
37- background_tasks = BackgroundTasks ()
3835 try :
3936 logger .info (f"> (CodemodRunMiddleware) Request: { request .url .path } " )
4037 self .runner .codebase .viz .clear_graphviz_data ()
4138 response = await call_next (request )
42- background_tasks .add_task (self .cleanup_after_codemod , is_exception = False )
43- response .background = background_tasks
4439 return response
4540
4641 except UserCodeException as e :
@@ -52,21 +47,4 @@ async def process_request(self, request: TRequest, call_next: RequestResponseEnd
5247 message = f"Unexpected error for { request .url .path } "
5348 logger .exception (message )
5449 res = JSONResponse (status_code = HTTPStatus .INTERNAL_SERVER_ERROR , content = {"detail" : message , "error" : str (e ), "traceback" : traceback .format_exc ()})
55- background_tasks .add_task (self .cleanup_after_codemod , is_exception = True )
56- res .background = background_tasks
5750 return res
58-
59- async def cleanup_after_codemod (self , is_exception : bool = False ):
60- if is_exception :
61- # TODO: instead of committing transactions, we should just rollback
62- logger .info ("Committing pending transactions due to exception" )
63- self .runner .codebase .ctx .commit_transactions (sync_graph = False )
64- await self .reset_runner ()
65-
66- @stopwatch
67- async def reset_runner (self ):
68- logger .info ("=====[ reset_runner ]=====" )
69- logger .info (f"Syncing runner to commit: { self .runner .commit } ..." )
70- self .runner .codebase .checkout (commit = self .runner .commit )
71- self .runner .codebase .clean_repo ()
72- self .runner .codebase .checkout (branch = self .runner .codebase .default_branch , create_if_missing = True )
0 commit comments