diff --git a/src/auto_uv.py b/src/auto_uv.py index 2a1aca5..ff129d4 100644 --- a/src/auto_uv.py +++ b/src/auto_uv.py @@ -3,16 +3,39 @@ import subprocess -def should_use_uv(): +def _find_project_root(start_dir, max_depth=10): + """ + Walk up from start_dir looking for uv project markers. + + Returns the project root directory if found, None otherwise. + """ + check_dir = os.path.abspath(start_dir) + for _ in range(max_depth): + # Check for uv project markers + if (os.path.isfile(os.path.join(check_dir, "pyproject.toml")) or + os.path.isdir(os.path.join(check_dir, ".venv")) or + os.path.isfile(os.path.join(check_dir, "uv.lock"))): + return check_dir + + # Move up one directory + parent = os.path.dirname(check_dir) + if parent == check_dir: # Reached root + break + check_dir = parent + + return None + + +def should_use_uv(script_path=None): """Check if we should intercept and use uv run.""" # Don't intercept if we're already running under uv if os.environ.get("UV_RUN_ACTIVE"): return False - + # Don't intercept if AUTO_UV is explicitly disabled if os.environ.get("AUTO_UV_DISABLE", "").lower() in ("1", "true", "yes"): return False - + # Check if uv is available try: subprocess.run( @@ -23,30 +46,21 @@ def should_use_uv(): ) except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired): return False - - # Check if we're in a uv project (has pyproject.toml or .venv) - # This is important for the use case: "I'm in a project dir, run python script.py" - # We want to use uv run to pick up the project's environment - current_dir = os.getcwd() - - # Walk up the directory tree looking for project markers - check_dir = current_dir - max_depth = 10 # Don't search too far up - for _ in range(max_depth): - # Check for uv project markers - if (os.path.isfile(os.path.join(check_dir, "pyproject.toml")) or - os.path.isdir(os.path.join(check_dir, ".venv")) or - os.path.isfile(os.path.join(check_dir, "uv.lock"))): - return True - - # Move up one directory - parent = os.path.dirname(check_dir) - if parent == check_dir: # Reached root - break - check_dir = parent - - # No project markers found, don't intercept - return False + + # Check if we're in a uv project (has pyproject.toml or .venv or uv.lock) + # First, try to find project root from the script's directory (if provided) + # This handles the case: "cd /some/dir && python /path/to/project/script.py" + # Then fall back to current working directory for backwards compatibility + project_root = None + + if script_path: + script_dir = os.path.dirname(os.path.abspath(script_path)) + project_root = _find_project_root(script_dir) + + if not project_root: + project_root = _find_project_root(os.getcwd()) + + return project_root is not None def auto_use_uv(): @@ -118,7 +132,7 @@ def auto_use_uv(): if script_path.startswith(sys_dir + os.path.sep) or script_path == sys_dir: return - if should_use_uv(): + if should_use_uv(script_path): # Set environment variable to prevent infinite loop os.environ["UV_RUN_ACTIVE"] = "1" diff --git a/tests/test_auto_uv.py b/tests/test_auto_uv.py index 896f3f2..81cef2a 100644 --- a/tests/test_auto_uv.py +++ b/tests/test_auto_uv.py @@ -444,10 +444,75 @@ def test_no_interception_during_import(): os.unlink(script_path) +def test_script_path_detection(): + """ + Test that auto-uv detects projects based on script location, not just CWD. + + This covers the use case: "cd /some/dir && python /path/to/project/script.py" + auto-uv should detect the project from the script's directory, not from CWD. + """ + from auto_uv import should_use_uv, _find_project_root + + # Save original directory and environment + original_dir = os.getcwd() + original_env = os.environ.copy() + + try: + # Temporarily remove AUTO_UV_DISABLE to test project detection + if "AUTO_UV_DISABLE" in os.environ: + del os.environ["AUTO_UV_DISABLE"] + if "UV_RUN_ACTIVE" in os.environ: + del os.environ["UV_RUN_ACTIVE"] + + with tempfile.TemporaryDirectory() as tmpdir: + # Create a project directory with pyproject.toml + project_dir = os.path.join(tmpdir, "my_project") + os.makedirs(project_dir) + with open(os.path.join(project_dir, "pyproject.toml"), "w") as f: + f.write("[project]\nname = 'test'\n") + + # Create a script inside the project + script_path = os.path.join(project_dir, "my_script.py") + with open(script_path, "w") as f: + f.write("print('hello')") + + # Create a completely unrelated directory (no project markers) + unrelated_dir = os.path.join(tmpdir, "unrelated") + os.makedirs(unrelated_dir) + + # Test 1: From unrelated dir, should_use_uv() without script_path returns False + os.chdir(unrelated_dir) + result = should_use_uv() + print(f"From unrelated dir, no script_path: should_use_uv = {result}") + assert result is False, "Should NOT detect project when CWD has no markers and no script_path" + + # Test 2: From unrelated dir, should_use_uv(script_path) returns True + result = should_use_uv(script_path) + print(f"From unrelated dir, with script_path: should_use_uv = {result}") + assert result is True, "Should detect project from script's directory even when CWD has no markers" + + # Test 3: _find_project_root from script directory finds the project + script_dir = os.path.dirname(script_path) + project_root = _find_project_root(script_dir) + print(f"_find_project_root from script dir: {project_root}") + assert project_root == project_dir, f"Expected {project_dir}, got {project_root}" + + # Test 4: _find_project_root from unrelated dir returns None + project_root = _find_project_root(unrelated_dir) + print(f"_find_project_root from unrelated dir: {project_root}") + assert project_root is None, "Should not find project root from unrelated directory" + + finally: + # Restore original directory and environment + os.chdir(original_dir) + os.environ.clear() + os.environ.update(original_env) + + def test_path_normalization(): """ Test that relative and absolute paths are normalized before comparison. - + This prevents the bug where __main__.__file__ is relative but sys.argv[0] is absolute (or vice versa), causing incorrect interception. """ @@ -548,4 +613,11 @@ def test_path_normalization(): except Exception as e: print(f"FAILED: {e}\n") + print("Test 10: Script path detection (detect project from script location)") + try: + test_script_path_detection() + print("PASSED\n") + except Exception as e: + print(f"FAILED: {e}\n") + print("Tests complete!")