diff --git a/.gitignore b/.gitignore index 843a75b..05446df 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,6 @@ contrib .idea .env + +# Claude local settings +Claude.local.md \ No newline at end of file diff --git a/mcp_server/djangomcp.py b/mcp_server/djangomcp.py index 9470cdc..f2ed970 100644 --- a/mcp_server/djangomcp.py +++ b/mcp_server/djangomcp.py @@ -17,6 +17,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from rest_framework.mixins import CreateModelMixin, UpdateModelMixin, DestroyModelMixin, ListModelMixin from rest_framework.serializers import Serializer +from rest_framework.test import APIRequestFactory from starlette.datastructures import Headers from starlette.types import Scope, Receive, Send @@ -460,39 +461,35 @@ def init(): class _DRFRequestWrapper(HttpRequest): - def __init__(self, mcp_server, mcp_request, method, body_json=None, id=None): - super().__init__() - self._serialized_body = json.dumps(body_json).encode("utf-8") if body_json else b'' - self.method = method - self.content_type = "application/json" - self.META = { - 'CONTENT_TYPE': 'application/json', - 'HTTP_ACCEPT': 'application/json', - 'CONTENT_LENGTH': len(self._serialized_body) - } - - self._stream = BytesIO(self._serialized_body) - self._read_started = False - self.user = mcp_request.user - self.session = mcp_request.session - self.original_request = mcp_request - self.path = f'/_djangomcpserver/{mcp_server.name}' + def __new__(cls, mcp_server, mcp_request, method, body_json=None, id=None): + # Using APIRequestFactory ensures that all the attributes DRF is expecting are correctly set on the HttpRequest + factory = APIRequestFactory() + path = f'/_djangomcpserver/{mcp_server.name}' if id: - self.path += f"/{id}" + path += f"/{id}" + + if method == 'POST': + request = factory.post(path, body_json, format='json') + elif method == 'PUT': + request = factory.put(path, body_json, format='json') + elif method == 'DELETE': + request = factory.delete(path) + elif method == 'GET': + request = factory.get(path) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + if mcp_request.user: + request.user = mcp_request.user + if mcp_request.session: + request.session = mcp_request.session + + return request class BaseAPIViewCallerTool: view: Type["APIView"] - @staticmethod - def _patched_initialize_request(self, request, *args, **kwargs): - original_request = request.original_request - original_request.request = request - original_request.method = request.method - return original_request - def __init__(self, view_class, **kwargs): - view_class.initialize_request = self._patched_initialize_request self.view = view_class.as_view(**kwargs)