diff --git a/README.md b/README.md index 5c5f16e..49755e5 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ uvx semgrep-mcp # see --help for more options Or, run as a [Docker container](https://ghcr.io/semgrep/mcp): ```bash -docker run -i --rm ghcr.io/semgrep/mcp -t stdio +docker run -i --rm ghcr.io/semgrep/mcp -t stdio ``` ### Cursor @@ -473,7 +473,7 @@ async def main(): { "code_files": [ { - "filename": "hello_world.py", + "path": "hello_world.py", "content": "def hello(): print('Hello, World!')", } ] diff --git a/examples/sse_client.py b/examples/sse_client.py index d93edd9..c0c1243 100644 --- a/examples/sse_client.py +++ b/examples/sse_client.py @@ -18,7 +18,7 @@ async def main(): { "code_files": [ { - "filename": "hello_world.py", + "path": "hello_world.py", "content": "def hello(): print('Hello, World!')", } ] diff --git a/examples/streamable_http_client.py b/examples/streamable_http_client.py index e4259b7..388395e 100644 --- a/examples/streamable_http_client.py +++ b/examples/streamable_http_client.py @@ -19,7 +19,7 @@ async def main(): { "code_files": [ { - "filename": "hello_world.py", + "path": "hello_world.py", "content": "def hello(): print('Hello, World!')", } ] diff --git a/src/semgrep_mcp/models.py b/src/semgrep_mcp/models.py index a648870..fa76a09 100644 --- a/src/semgrep_mcp/models.py +++ b/src/semgrep_mcp/models.py @@ -4,12 +4,13 @@ from pydantic import BaseModel, Field, HttpUrl -class LocalCodeFile(BaseModel): - path: str = Field(description="Absolute path to be scanned locally by Semgrep.") - - class CodeFile(BaseModel): - filename: str = Field(description="Relative path to the code file") + # This "path" is mostly for bookkeeping purposes. + # Depending on whether the server is hosted or not, this path might + # not actually exist on the filesystem. 
+ path: str = Field(description="Path of the code file") + # The `content` field will be filled in either by the LLM (in the remote scanning case) + # or gleaned from the filesystem (in the local scanning case). content: str = Field(description="Content of the code file") diff --git a/src/semgrep_mcp/semgrep.py b/src/semgrep_mcp/semgrep.py index a3cdb3b..20cd96a 100644 --- a/src/semgrep_mcp/semgrep.py +++ b/src/semgrep_mcp/semgrep.py @@ -275,7 +275,9 @@ async def run_semgrep_via_rpc(context: SemgrepContext, data: list[CodeFile]) -> List of CliMatch objects """ - files_json = [{"file": data.filename, "content": data.content} for data in data] + # TODO: to be honest it's silly for us to wire the contents of the files over RPC + # if they exist on the local filesystem, we could just pass the paths + files_json = [{"file": data.path, "content": data.content} for data in data] # ATD serialized value resp = await context.send_request("scanFiles", files=files_json) diff --git a/src/semgrep_mcp/server.py b/src/semgrep_mcp/server.py index c521d8f..146705d 100755 --- a/src/semgrep_mcp/server.py +++ b/src/semgrep_mcp/server.py @@ -24,7 +24,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse -from semgrep_mcp.models import CodeFile, Finding, LocalCodeFile, SemgrepScanResult +from semgrep_mcp.models import CodeFile, Finding, SemgrepScanResult from semgrep_mcp.semgrep import ( SemgrepContext, mk_context, @@ -33,7 +33,6 @@ ) from semgrep_mcp.semgrep_interfaces.semgrep_output_v1 import CliOutput from semgrep_mcp.utilities.tracing import ( - attach_metrics, attach_rpc_scan_metrics, attach_scan_metrics, start_tracing, @@ -42,6 +41,7 @@ from semgrep_mcp.utilities.utils import ( get_semgrep_app_token, get_semgrep_version, + is_hosted, set_semgrep_executable, ) from semgrep_mcp.version import __version__ @@ -54,7 +54,7 @@ SEMGREP_API_VERSION = "v1" # Field definitions for function parameters -CODE_FILES_FIELD = Field(description="List of 
dictionaries with 'filename' and 'content' keys") +REMOTE_CODE_FILES_FIELD = Field(description="List of dictionaries with 'path' and 'content' keys") LOCAL_CODE_FILES_FIELD = Field( description=("List of dictionaries with 'path' pointing to the absolute path of the code file") ) @@ -104,7 +104,7 @@ def safe_join(base_dir: str, untrusted_path: str) -> str: # Ensure untrusted path is not absolute # This is soft validation, path traversal is checked later - if os.path.isabs(untrusted_path): + if Path(untrusted_path).is_absolute(): raise ValueError("Untrusted path must be relative") # Join and normalize the untrusted path @@ -120,7 +120,7 @@ def safe_join(base_dir: str, untrusted_path: str) -> str: # Path validation def validate_absolute_path(path_to_validate: str, param_name: str) -> str: """Validates an absolute path to ensure it's safe to use""" - if not os.path.isabs(path_to_validate): + if not Path(path_to_validate).is_absolute(): raise McpError( ErrorData( code=INVALID_PARAMS, @@ -174,7 +174,7 @@ def create_temp_files_from_code_content(code_files: list[CodeFile]) -> str: # Create files in the temporary directory for file_info in code_files: - filename = file_info.filename + filename = file_info.path if not filename: continue @@ -227,12 +227,12 @@ def get_semgrep_scan_args(temp_dir: str, config: str | None = None) -> list[str] return args -def validate_local_files(local_files: list[dict[str, str]]) -> list[LocalCodeFile]: +def validate_local_files(local_files: list[dict[str, str]]) -> list[CodeFile]: """ Validates the local_files parameter for semgrep scan using Pydantic validation Args: - local_files: List of LocalCodeFile objects + local_files: List of singleton dictionaries with a "path" key Raises: McpError: If validation fails @@ -245,28 +245,37 @@ def validate_local_files(local_files: list[dict[str, str]]) -> list[LocalCodeFil ) try: # Pydantic will automatically validate each item in the list - validated_local_files = [LocalCodeFile.model_validate(file) for 
file in local_files] + validated_local_files = [] + for file in local_files: + path = file["path"] + if not Path(path).is_absolute(): + raise McpError( + ErrorData( + code=INVALID_PARAMS, message="code_files.path must be a absolute path" + ) + ) + contents = Path(path).read_text() + # We need to not use the absolute path here, as there is logic later + # that raises, to prevent path traversal. + # In reality, the name of the file is pretty immaterial. We only + # want the accurate path insofar as we can get the contents (which we do here) + # and so we can remember what original file it corresponds to. + # Taking the name of the file should be enough. + validated_local_files.append(CodeFile(path=Path(path).name, content=contents)) except Exception as e: raise McpError( - ErrorData(code=INVALID_PARAMS, message=f"Invalid code files format: {e!s}") + ErrorData(code=INVALID_PARAMS, message=f"Invalid local code files format: {e!s}") ) from e - for file in validated_local_files: - if not os.path.isabs(file.path): - raise McpError( - ErrorData( - code=INVALID_PARAMS, message="code_files.filename must be a absolute path" - ) - ) return validated_local_files -def validate_code_files(code_files: list[dict[str, str]]) -> list[CodeFile]: +def validate_remote_files(code_files: list[dict[str, str]]) -> list[CodeFile]: """ Validates the code_files parameter for semgrep scan using Pydantic validation Args: - code_files: List of CodeFile objects + code_files: List of dictionaries with a "path" and "content" key Raises: McpError: If validation fails @@ -280,19 +289,12 @@ def validate_code_files(code_files: list[dict[str, str]]) -> list[CodeFile]: try: # Pydantic will automatically validate each item in the list validated_code_files = [CodeFile.model_validate(file) for file in code_files] + + return validated_code_files except Exception as e: raise McpError( - ErrorData(code=INVALID_PARAMS, message=f"Invalid code files format: {e!s}") + ErrorData(code=INVALID_PARAMS, 
message=f"Invalid remote code files format: {e!s}") ) from e - for file in validated_code_files: - if os.path.isabs(file.filename): - raise McpError( - ErrorData( - code=INVALID_PARAMS, message="code_files.filename must be a relative path" - ) - ) - - return validated_code_files def remove_temp_dir_from_results(results: SemgrepScanResult, temp_dir: str) -> None: @@ -618,7 +620,7 @@ async def semgrep_findings( @with_tool_span() async def semgrep_scan_with_custom_rule( ctx: Context, - code_files: list[dict[str, str]] = CODE_FILES_FIELD, + code_files: list[dict[str, str]] = REMOTE_CODE_FILES_FIELD, rule: str = RULE_FIELD, ) -> SemgrepScanResult: """ @@ -630,7 +632,7 @@ async def semgrep_scan_with_custom_rule( - scan code files for specific issue not covered by the default Semgrep rules """ # Validate code_files - validated_code_files = validate_code_files(code_files) + validated_code_files = validate_remote_files(code_files) temp_dir = None try: # Create temporary files from code content @@ -667,12 +669,77 @@ async def semgrep_scan_with_custom_rule( shutil.rmtree(temp_dir, ignore_errors=True) +@mcp.tool() +@with_tool_span() +async def get_abstract_syntax_tree( + ctx: Context, + code: str = Field(description="The code to get the AST for"), + language: str = Field(description="The programming language of the code"), +) -> str: + """ + Returns the Abstract Syntax Tree (AST) for the provided code file in JSON format + + Use this tool when you need to: + - get the Abstract Syntax Tree (AST) for the provided code file\ + - get the AST of a file + - understand the structure of the code in a more granular way + - see what a parser sees in the code + """ + temp_dir = None + temp_file_path = "" + try: + # Create temporary directory and file for AST generation + temp_dir = tempfile.mkdtemp(prefix="semgrep_ast_") + temp_file_path = os.path.join(temp_dir, "code.txt") # safe + + # Write content to file + with open(temp_file_path, "w") as f: + f.write(code) + + args = [ + 
"--experimental", + "--dump-ast", + "-l", + language, + "--json", + temp_file_path, + ] + return await run_semgrep_output(top_level_span=None, args=args) + + except McpError as e: + raise e + except ValidationError as e: + raise McpError( + ErrorData(code=INTERNAL_ERROR, message=f"Error parsing semgrep output: {e!s}") + ) from e + except OSError as e: + raise McpError( + ErrorData( + code=INTERNAL_ERROR, + message=f"Failed to create or write to file {temp_file_path}: {e!s}", + ) + ) from e + except Exception as e: + raise McpError( + ErrorData(code=INTERNAL_ERROR, message=f"Error running semgrep scan: {e!s}") + ) from e + finally: + if temp_dir: + # Clean up temporary files + shutil.rmtree(temp_dir, ignore_errors=True) + + +# --------------------------------------------------------------------------------- +# Scanning tools +# --------------------------------------------------------------------------------- + + @with_tool_span() async def semgrep_scan_cli( ctx: Context, code_files: list[CodeFile], config: str | None = CONFIG_FIELD, -) -> SemgrepScanResult | CliOutput: +) -> SemgrepScanResult: """ Runs a Semgrep scan on provided code content and returns the findings in JSON format @@ -685,6 +752,7 @@ async def semgrep_scan_cli( - scan code files for security vulnerabilities - scan code files for other issues """ + # Validate config config = validate_config(config) @@ -755,34 +823,23 @@ async def semgrep_scan_rpc( shutil.rmtree(temp_dir, ignore_errors=True) -@mcp.tool() -@with_tool_span() -async def semgrep_scan( +async def semgrep_scan_core( ctx: Context, - code_files: list[dict[str, str]] = CODE_FILES_FIELD, - # TODO: currently only for CLI-based scans + code_files: list[CodeFile], config: str | None = CONFIG_FIELD, ) -> SemgrepScanResult | CliOutput: """ - Runs a Semgrep scan on provided code content and returns the findings in JSON format - - Use this tool when you need to: - - scan code files for security vulnerabilities - - scan code files for other issues - 
""" + Runs a Semgrep scan on provided CodeFile objects and returns the findings in JSON format - # Implementer's note: - # Depending on whether `USE_SEMGREP_RPC` is set, this tool will either run a `pysemgrep` - # CLI scan, or an RPC-based scan. - # Respectively, this will cause us to return either a `SemgrepScanResult` or a `CliOutput`. - # I put this here, in a comment, so the MCP doesn't need to be aware - # of these differences. + Depending on whether `USE_SEMGREP_RPC` is set, this tool will either run a `pysemgrep` + CLI scan, or an RPC-based scan. - validated_code_files = validate_code_files(code_files) + Respectively, this will cause us to return either a `SemgrepScanResult` or a `CliOutput`. + """ context: SemgrepContext = ctx.request_context.lifespan_context - paths = [cf.filename for cf in validated_code_files] + paths = [cf.path for cf in code_files] if context.process is not None: if config is not None: @@ -792,150 +849,72 @@ async def semgrep_scan( ErrorData( code=INVALID_PARAMS, message=""" - `config` is not supported when using the RPC-based scan. - Try calling again without that parameter set? - """, + `config` is not supported when using the RPC-based scan. + Try calling again without that parameter set? 
+ """, ) ) logging.info(f"Running RPC-based scan on paths: {paths}") - return await semgrep_scan_rpc(ctx, validated_code_files) + return await semgrep_scan_rpc(ctx, code_files) else: logging.info(f"Running CLI-based scan on paths: {paths}") - return await semgrep_scan_cli(ctx, validated_code_files, config) + return await semgrep_scan_cli(ctx, code_files, config) @mcp.tool() @with_tool_span() -async def semgrep_scan_local( +async def semgrep_scan_remote( ctx: Context, - code_files: list[dict[str, str]] = LOCAL_CODE_FILES_FIELD, + code_files: list[dict[str, str]] = REMOTE_CODE_FILES_FIELD, + # TODO: currently only for CLI-based scans config: str | None = CONFIG_FIELD, -) -> list[SemgrepScanResult]: +) -> SemgrepScanResult | CliOutput: """ - Runs a Semgrep scan locally on provided code files returns the findings in JSON format. - - Files are expected to be in the current paths are absolute paths to the code files. + Runs a Semgrep scan on provided code content and returns the findings in JSON format Use this tool when you need to: - scan code files for security vulnerabilities - scan code files for other issues """ - import os - - if not os.environ.get("SEMGREP_ALLOW_LOCAL_SCAN"): - raise McpError( - ErrorData( - code=INVALID_PARAMS, - message=( - "Local Semgrep scans are not allowed unless SEMGREP_ALLOW_LOCAL_SCAN is set" - ), - ) - ) - # Validate config - config = validate_config(config) - validated_local_files = validate_local_files(code_files) - - temp_dir = None - try: - results, skipped_rules, scanned_paths, findings, errors = [], [], [], [], [] - for cf in validated_local_files: - args = get_semgrep_scan_args(cf.path, config) - output = await run_semgrep_output(top_level_span=None, args=args) - result: SemgrepScanResult = SemgrepScanResult.model_validate_json(output) - results.append(result) - skipped_rules.extend(result.skipped_rules) - scanned_paths.extend(result.paths["scanned"]) - findings.extend(result.results) - errors.extend(result.errors) - - 
- attach_metrics( - get_current_span(), - results[0].version, - skipped_rules, - scanned_paths, - findings, - errors, - config, - ) - return results + # Implementer's note: + # This is one possible entry point for regular scanning, depending on whether + # the server is remotely hosted or not. + # If the server is hosted, only this tool will be available, and not the + # `semgrep_scan` tool. - except McpError as e: - raise e - except ValidationError as e: - raise McpError( - ErrorData(code=INTERNAL_ERROR, message=f"Error parsing semgrep output: {e!s}") - ) from e - except Exception as e: - raise McpError( - ErrorData(code=INTERNAL_ERROR, message=f"Error running semgrep scan: {e!s}") - ) from e + validated_code_files = validate_remote_files(code_files) - finally: - if temp_dir: - # Clean up temporary files - shutil.rmtree(temp_dir, ignore_errors=True) + return await semgrep_scan_core(ctx, validated_code_files, config) @mcp.tool() @with_tool_span() -async def get_abstract_syntax_tree( +async def semgrep_scan( ctx: Context, - code: str = Field(description="The code to get the AST for"), - language: str = Field(description="The programming language of the code"), -) -> str: + code_files: list[dict[str, str]] = LOCAL_CODE_FILES_FIELD, + config: str | None = CONFIG_FIELD, +) -> SemgrepScanResult | CliOutput: """ - Returns the Abstract Syntax Tree (AST) for the provided code file in JSON format + Runs a Semgrep scan locally on provided code files and returns the findings in JSON format. + + Files are expected to be absolute paths to the code files. 
Use this tool when you need to: - - get the Abstract Syntax Tree (AST) for the provided code file\ - - get the AST of a file - - understand the structure of the code in a more granular way - - see what a parser sees in the code + - scan code files for security vulnerabilities + - scan code files for other issues """ - temp_dir = None - temp_file_path = "" - try: - # Create temporary directory and file for AST generation - temp_dir = tempfile.mkdtemp(prefix="semgrep_ast_") - temp_file_path = os.path.join(temp_dir, "code.txt") # safe - # Write content to file - with open(temp_file_path, "w") as f: - f.write(code) + # Implementer's note: + # This is one possible entry point for regular scanning, depending on whether + # the server is remotely hosted or not. + # If the server is local, only this tool will be available, and not the + # `semgrep_scan_remote` tool. - args = [ - "--experimental", - "--dump-ast", - "-l", - language, - "--json", - temp_file_path, - ] - return await run_semgrep_output(top_level_span=None, args=args) + validated_local_files = validate_local_files(code_files) - except McpError as e: - raise e - except ValidationError as e: - raise McpError( - ErrorData(code=INTERNAL_ERROR, message=f"Error parsing semgrep output: {e!s}") - ) from e - except OSError as e: - raise McpError( - ErrorData( - code=INTERNAL_ERROR, - message=f"Failed to create or write to file {temp_file_path}: {e!s}", - ) - ) from e - except Exception as e: - raise McpError( - ErrorData(code=INTERNAL_ERROR, message=f"Error running semgrep scan: {e!s}") - ) from e - finally: - if temp_dir: - # Clean up temporary files - shutil.rmtree(temp_dir, ignore_errors=True) + return await semgrep_scan_core(ctx, validated_local_files, config) # --------------------------------------------------------------------------------- @@ -1067,7 +1046,7 @@ async def health(request: Request) -> JSONResponse: "SEMGREP_FINDINGS_DISABLED": "semgrep_findings", "SEMGREP_SCAN_WITH_CUSTOM_RULE_DISABLED": 
"semgrep_scan_with_custom_rule", "SEMGREP_SCAN_DISABLED": "semgrep_scan", - "SEMGREP_SCAN_LOCAL_DISABLED": "semgrep_scan_local", + "SEMGREP_SCAN_REMOTE_DISABLED": "semgrep_scan_remote", "GET_ABSTRACT_SYNTAX_TREE_DISABLED": "get_abstract_syntax_tree", } @@ -1082,6 +1061,11 @@ def deregister_tools() -> None: # not stop us from doing so del mcp._tool_manager._tools[tool_name] + if is_hosted(): + del mcp._tool_manager._tools["semgrep_scan"] + else: + del mcp._tool_manager._tools["semgrep_scan_remote"] + # --------------------------------------------------------------------------------- # MCP Server Entry Point diff --git a/tests/integration/test_create_temp_files.py b/tests/integration/test_create_temp_files.py index c415983..88d5a8b 100644 --- a/tests/integration/test_create_temp_files.py +++ b/tests/integration/test_create_temp_files.py @@ -10,9 +10,9 @@ def test_create_temp_files_from_code_content(): """Test that create_temp_files_from_code_content correctly creates temp files with content""" # Define test code files code_files = [ - CodeFile(filename="test_file.py", content="print('Hello, world!')"), - CodeFile(filename="nested/path/test_file.js", content="console.log('Hello, world!');"), - CodeFile(filename="special chars/file with spaces.txt", content="Hello, world!"), + CodeFile(path="test_file.py", content="print('Hello, world!')"), + CodeFile(path="nested/path/test_file.js", content="console.log('Hello, world!');"), + CodeFile(path="special chars/file with spaces.txt", content="Hello, world!"), ] # Call the function @@ -26,7 +26,7 @@ def test_create_temp_files_from_code_content(): # Check if files were created with correct content for code_file in code_files: - file_path = os.path.join(temp_dir, code_file.filename) + file_path = os.path.join(temp_dir, code_file.path) assert os.path.exists(file_path) with open(file_path) as f: content = f.read() @@ -71,8 +71,8 @@ def test_create_temp_files_from_code_content_empty_list(): def 
test_create_temp_files_from_code_content_empty_filename(): """Test that create_temp_files_from_code_content handles empty filenames""" code_files = [ - CodeFile(filename="", content="This content should be skipped"), - CodeFile(filename="valid_file.txt", content="This is valid content"), + CodeFile(path="", content="This content should be skipped"), + CodeFile(path="valid_file.txt", content="This is valid content"), ] temp_dir = None @@ -111,9 +111,9 @@ def test_create_temp_files_from_code_content_path_traversal(): """Test that create_temp_files_from_code_content prevents path traversal""" # Define test code files with path traversal attempts code_files = [ - CodeFile(filename="../attempt_to_write_outside.txt", content="This should fail"), - CodeFile(filename="subdir/../../../etc/passwd", content="This should fail too"), - CodeFile(filename="/absolute/path/file.txt", content="This should fail as well"), + CodeFile(path="../attempt_to_write_outside.txt", content="This should fail"), + CodeFile(path="subdir/../../../etc/passwd", content="This should fail too"), + CodeFile(path="/absolute/path/file.txt", content="This should fail as well"), ] # The function should raise a ValueError for path traversal attempts diff --git a/tests/integration/test_local_scan.py b/tests/integration/test_local_scan.py new file mode 100644 index 0000000..28858f0 --- /dev/null +++ b/tests/integration/test_local_scan.py @@ -0,0 +1,62 @@ +import json +import os +import subprocess +import time +from pathlib import Path +from tempfile import NamedTemporaryFile + +import pytest +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamablehttp_client + +base_url = os.getenv("MCP_BASE_URL", "http://127.0.0.1:8000") + + +@pytest.fixture(scope="module") +def streamable_server(): + # Start the streamable-http server + proc = subprocess.Popen( + ["python", "src/semgrep_mcp/server.py", "-t", "streamable-http"], + ) + # Wait briefly to ensure the server starts + 
time.sleep(5) + yield + # Teardown: terminate the server + proc.terminate() + proc.wait() + + +@pytest.mark.asyncio +async def test_local_scan(streamable_server): + async with streamablehttp_client(f"{base_url}/mcp") as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + # Initializing session... + await session.initialize() + # Session initialized + + with NamedTemporaryFile( + "w", prefix="hello_world", suffix=".py", encoding="utf-8" + ) as tmp: + tmp.write("def hello(): print('Hello, World!')") + tmp.flush() + + path = tmp.name + + # Scan code for security issues using local semgrep_scan + results = await session.call_tool( + "semgrep_scan", + { + "code_files": [ + { + "path": str(Path(path).absolute()), + } + ] + }, + ) + # We have results! + assert results is not None + content = json.loads(results.content[0].text) # type: ignore + assert isinstance(content, dict) + assert len(content["paths"]["scanned"]) == 1 + assert content["paths"]["scanned"][0].startswith("hello_world") + print(json.dumps(content, indent=2)) diff --git a/tests/integration/test_sse_client.py b/tests/integration/test_sse_client.py index 7b206a2..96f2c47 100644 --- a/tests/integration/test_sse_client.py +++ b/tests/integration/test_sse_client.py @@ -15,7 +15,10 @@ @pytest.fixture(scope="module") def sse_server(): # Start the SSE server - proc = subprocess.Popen(["python", "src/semgrep_mcp/server.py", "-t", "sse"]) + proc = subprocess.Popen( + ["python", "src/semgrep_mcp/server.py", "-t", "sse"], + env={"SEMGREP_IS_HOSTED": "true", **os.environ}, + ) # Wait briefly to ensure the server starts time.sleep(5) yield @@ -34,11 +37,11 @@ async def test_sse_client_smoke(sse_server): # Scan code for security issues results = await session.call_tool( - "semgrep_scan", + "semgrep_scan_remote", { "code_files": [ { - "filename": "hello_world.py", + "path": "hello_world.py", "content": "def hello(): print('Hello, World!')", } ] diff --git 
a/tests/integration/test_stdio_client.py b/tests/integration/test_stdio_client.py index dd0e8d3..f744c29 100644 --- a/tests/integration/test_stdio_client.py +++ b/tests/integration/test_stdio_client.py @@ -1,4 +1,5 @@ import json +import os import pytest from mcp import ClientSession, StdioServerParameters @@ -8,7 +9,11 @@ server_params = StdioServerParameters( command="python", # Executable args=["src/semgrep_mcp/server.py"], # Optional command line arguments - env={"USE_SEMGREP_RPC": "false"}, # Optional environment variables + env={ + "USE_SEMGREP_RPC": "false", + "SEMGREP_IS_HOSTED": "true", + **os.environ, + }, # Optional environment variables ) @@ -39,11 +44,11 @@ async def test_stdio_client(): # Call a tool results = await session.call_tool( - "semgrep_scan", + "semgrep_scan_remote", { "code_files": [ { - "filename": "hello_world.py", + "path": "hello_world.py", "content": "def hello(): print('Hello, World!')", } ] diff --git a/tests/integration/test_streamable_client.py b/tests/integration/test_streamable_client.py index 817e881..cafbbd8 100644 --- a/tests/integration/test_streamable_client.py +++ b/tests/integration/test_streamable_client.py @@ -13,9 +13,12 @@ @pytest.fixture(scope="module") def streamable_server(): # Start the streamable-http server - proc = subprocess.Popen(["python", "src/semgrep_mcp/server.py", "-t", "streamable-http"]) + proc = subprocess.Popen( + ["python", "src/semgrep_mcp/server.py", "-t", "streamable-http"], + env={"SEMGREP_IS_HOSTED": "true", **os.environ}, + ) # Wait briefly to ensure the server starts - time.sleep(2) + time.sleep(5) yield # Teardown: terminate the server proc.terminate() @@ -32,11 +35,11 @@ async def test_streamable_client_smoke(streamable_server): # Scan code for security issues results = await session.call_tool( - "semgrep_scan", + "semgrep_scan_remote", { "code_files": [ { - "filename": "hello_world.py", + "path": "hello_world.py", "content": "def hello(): print('Hello, World!')", } ] diff --git a/uv.lock 
b/uv.lock index 39385f2..f24a3d8 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" [[package]]