Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions gitshield/cli.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""GitShield CLI — Prevent accidental secret commits."""

import re
import sys
from pathlib import Path

import click

from . import __version__
from .config import filter_findings, load_ignore_list, find_git_root
from .config import filter_findings, load_config, load_ignore_list, find_git_root
from .formatter import print_findings, print_json, print_blocked_message, colorize, Colors
from .scanner import scan_path, ScannerError

Expand All @@ -32,7 +33,8 @@ def scan(path: str, staged: bool, no_git: bool, as_json: bool, sarif: bool, quie

# Filter ignored
ignores = load_ignore_list(Path(path))
findings = filter_findings(findings, ignores)
config = load_config(Path(path))
findings = filter_findings(findings, ignores, config=config)

# Output
if sarif:
Expand Down Expand Up @@ -75,7 +77,7 @@ def hook_install(path: str):
hook_path = hooks_dir / "pre-commit"

# GitShield hook content
gitshield_hook = '\n\n# GitShield secret scan\nexport PATH="$PATH:$HOME/Library/Python/3.9/bin:$HOME/.local/bin"\ngitshield scan --staged --quiet\n'
gitshield_hook = '\n\n# GitShield secret scan\nexport PATH="$PATH:$HOME/.local/bin"\ngitshield scan --staged --quiet\n'

if hook_path.exists():
content = hook_path.read_text()
Expand All @@ -89,7 +91,7 @@ def hook_install(path: str):
hook_content = """#!/bin/sh
# GitShield pre-commit hook

export PATH="$PATH:$HOME/Library/Python/3.9/bin:$HOME/.local/bin"
export PATH="$PATH:$HOME/.local/bin"
gitshield scan --staged --quiet
"""
hook_path.write_text(hook_content)
Expand Down Expand Up @@ -119,16 +121,27 @@ def hook_uninstall(path: str):

lines = content.split("\n")
new_lines = []
skip_next = False
in_block = False

for line in lines:
if "# GitShield" in line:
skip_next = True
in_block = True
continue
if skip_next and "gitshield" in line:
skip_next = False
if in_block:
if not line.strip():
# Skip blank lines within the block
continue
if "gitshield" in line.lower():
# The gitshield command line — block ends after this
in_block = False
continue
if line.startswith("export"):
# PATH export added by gitshield block
continue
# Non-gitshield content encountered — end of block
in_block = False
new_lines.append(line)
continue
skip_next = False
new_lines.append(line)

new_content = "\n".join(new_lines).strip()
Expand Down Expand Up @@ -175,10 +188,15 @@ def claude_status():
@main.command()
@click.option("--path", "-p", default=".", type=click.Path(exists=True),
help="Repository path")
def init(path: str):
@click.option("--force", is_flag=True, help="Overwrite existing config")
def init(path: str, force: bool):
"""Create a .gitshield.toml config file with sensible defaults."""
from .config import create_default_config
config_path = create_default_config(Path(path))
try:
config_path = create_default_config(Path(path), force=force)
except FileExistsError as e:
click.echo(colorize(f"Error: {e}", Colors.RED), err=True)
sys.exit(1)
click.echo(colorize(f"Created {config_path}", Colors.GREEN))
click.echo("Edit this file to customize patterns, allowlists, and thresholds.")

Expand All @@ -192,7 +210,7 @@ def init(path: str):
@click.option("--stats", is_flag=True, help="Show scanning statistics")
def patrol(repo: str, limit: int, dry_run: bool, stats: bool):
"""Scan public GitHub repos for leaked secrets."""
from .monitor import fetch_public_events, fetch_repo_info, clone_and_scan
from .monitor import fetch_public_events, fetch_repo_info, clone_and_scan, GitHubError
from .notifier import notify
from .db import get_stats

Expand All @@ -209,6 +227,10 @@ def patrol(repo: str, limit: int, dry_run: bool, stats: bool):
click.echo(colorize("Error: Use format owner/name", Colors.RED), err=True)
sys.exit(1)
owner, name = repo.split("/", 1)
_valid_gh_name = re.compile(r'^[A-Za-z0-9._-]+$')
if not _valid_gh_name.match(owner) or not _valid_gh_name.match(name):
click.echo(colorize("Error: Invalid repo format", Colors.RED), err=True)
sys.exit(1)
repos = [fetch_repo_info(owner, name)]
click.echo(f"Scanning {repo}...")
else:
Expand Down Expand Up @@ -255,10 +277,10 @@ def patrol(repo: str, limit: int, dry_run: bool, stats: bool):
click.echo(f" Secrets found: {total_findings}")
click.echo(f" Notifications: {notified_count}")

except GitHubError as e:
click.echo(colorize(f"Error: {e}", Colors.RED), err=True)
sys.exit(1)
except Exception as e:
if "GitHubError" in type(e).__name__:
click.echo(colorize(f"Error: {e}", Colors.RED), err=True)
sys.exit(1)
click.echo(colorize(f"Error: {e}", Colors.RED), err=True)
sys.exit(2)

Expand Down
17 changes: 16 additions & 1 deletion gitshield/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import fnmatch
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
Expand Down Expand Up @@ -80,6 +81,14 @@ class GitShieldConfig:
custom_patterns: List[Dict[str, Any]] = field(default_factory=list)


# ---------------------------------------------------------------------------
# Credential helpers
# ---------------------------------------------------------------------------
def get_github_token() -> Optional[str]:
    """Return the GitHub token configured in the environment, if any.

    ``GITHUB_TOKEN`` is consulted first, then ``GH_TOKEN``. Surrounding
    whitespace (for example a trailing newline picked up when the variable
    was exported from a file) is stripped, and a value that is empty after
    stripping is treated as unset so it cannot mask the other variable.

    Returns:
        The first non-blank token found, or ``None`` when neither variable
        holds a usable value.
    """
    for var_name in ("GITHUB_TOKEN", "GH_TOKEN"):
        token = os.environ.get(var_name, "").strip()
        if token:
            return token
    return None


# ---------------------------------------------------------------------------
# Git root discovery
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -175,14 +184,20 @@ def load_config(path: Path) -> GitShieldConfig:
# ---------------------------------------------------------------------------
# Default config creation
# ---------------------------------------------------------------------------
def create_default_config(path: Path) -> Path:
def create_default_config(path: Path, force: bool = False) -> Path:
    """
    Write the default configuration file and return its location.

    The file is named ``CONFIG_FILE`` and is placed in the directory that
    ``find_git_root(path)`` returns. Its contents are the module-level
    ``_DEFAULT_CONFIG_TOML`` template, encoded as UTF-8.

    Raises FileExistsError if the target already exists and *force* is False;
    with *force* set, an existing file is overwritten.
    """
    target = find_git_root(path) / CONFIG_FILE

    already_present = target.exists()
    if already_present and not force:
        message = f"{target} already exists. Use --force to overwrite."
        raise FileExistsError(message)

    target.write_text(_DEFAULT_CONFIG_TOML, encoding="utf-8")
    return target

Expand Down
53 changes: 38 additions & 15 deletions gitshield/db.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,39 @@
"""SQLite database for tracking scanned repos and notifications."""

import atexit
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Optional
from typing import List, Optional, Set

# Database location
DB_DIR = Path.home() / ".gitshield"
DB_PATH = DB_DIR / "gitshield.db"

# Module-level singleton connection — initialized on first use.
_conn: Optional[sqlite3.Connection] = None


def _close_connection() -> None:
    """Close and forget the module-level connection, if one is open.

    Does nothing when no connection has been created. After a successful
    close the singleton is reset to ``None``.
    """
    global _conn
    if _conn is None:
        return
    _conn.close()
    _conn = None


atexit.register(_close_connection)


def get_connection() -> sqlite3.Connection:
"""Get database connection, creating tables if needed."""
DB_DIR.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
_init_tables(conn)
return conn
"""Return the module-level DB connection, creating it on first call."""
global _conn
if _conn is None:
DB_DIR.mkdir(parents=True, exist_ok=True)
_conn = sqlite3.connect(DB_PATH)
_conn.row_factory = sqlite3.Row
_init_tables(_conn)
return _conn


def _init_tables(conn: sqlite3.Connection) -> None:
Expand Down Expand Up @@ -51,7 +68,6 @@ def was_scanned_recently(repo_url: str, hours: int = 24) -> bool:
(repo_url,)
)
row = cursor.fetchone()
conn.close()

if not row:
return False
Expand All @@ -72,7 +88,6 @@ def mark_scanned(repo_url: str, findings_count: int = 0) -> None:
findings_count = excluded.findings_count
""", (repo_url, datetime.now().isoformat(), findings_count))
conn.commit()
conn.close()


def was_notified(repo_url: str, fingerprint: str) -> bool:
Expand All @@ -82,9 +97,7 @@ def was_notified(repo_url: str, fingerprint: str) -> bool:
"SELECT id FROM notifications WHERE repo_url = ? AND fingerprint = ?",
(repo_url, fingerprint)
)
result = cursor.fetchone() is not None
conn.close()
return result
return cursor.fetchone() is not None


def mark_notified(
Expand All @@ -101,7 +114,19 @@ def mark_notified(
VALUES (?, ?, ?, ?, ?)
""", (repo_url, email, fingerprint, datetime.now().isoformat(), method))
conn.commit()
conn.close()


def get_notified_fingerprints(repo_url: str, fingerprints: List[str]) -> Set[str]:
    """Return the subset of *fingerprints* that have already been notified.

    Args:
        repo_url: Repository URL the notifications were recorded against.
        fingerprints: Candidate finding fingerprints to look up.

    Returns:
        The fingerprints from the input that have a row in the
        ``notifications`` table for *repo_url*. Empty when *fingerprints*
        is empty.
    """
    if not fingerprints:
        return set()

    conn = get_connection()

    # SQLite limits the number of bound parameters per statement (999 in
    # builds older than 3.32.0). Query in batches so a repo with many
    # findings cannot trigger "too many SQL variables".
    batch_size = 500
    notified: Set[str] = set()
    for start in range(0, len(fingerprints), batch_size):
        batch = fingerprints[start:start + batch_size]
        # Only "?" placeholders are interpolated; values stay parameterized.
        placeholders = ",".join("?" * len(batch))
        cursor = conn.execute(
            f"SELECT fingerprint FROM notifications WHERE repo_url = ? AND fingerprint IN ({placeholders})",
            (repo_url, *batch),
        )
        notified.update(row["fingerprint"] for row in cursor.fetchall())
    return notified


def get_stats() -> dict:
Expand All @@ -112,8 +137,6 @@ def get_stats() -> dict:
findings = conn.execute("SELECT SUM(findings_count) FROM scanned_repos").fetchone()[0] or 0
notifications = conn.execute("SELECT COUNT(*) FROM notifications").fetchone()[0]

conn.close()

return {
"repos_scanned": repos,
"total_findings": findings,
Expand Down
54 changes: 35 additions & 19 deletions gitshield/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
"""

import fnmatch
import os
import subprocess
from pathlib import Path
from typing import List, Set, Union
from typing import List, Optional, Set, Union

from .models import Finding
from .patterns import entropy, PATTERNS
from .scanner import Finding

# Directories to always skip during tree walks.
_SKIP_DIRS: Set[str] = {
Expand Down Expand Up @@ -115,6 +116,7 @@ def scan_text(
text: str,
filename: str = "<stdin>",
line_offset: int = 0,
config_threshold: Optional[float] = None,
) -> List[Finding]:
"""Scan a text string line-by-line against all patterns.

Expand Down Expand Up @@ -146,8 +148,12 @@ def scan_text(
ent = entropy(matched_text)
if ent < pattern.entropy_threshold:
continue
else:
elif config_threshold is not None:
ent = entropy(matched_text)
if ent < config_threshold:
continue
else:
ent = 0.0

line_number = idx + line_offset

Expand Down Expand Up @@ -224,40 +230,48 @@ def scan_directory(

findings: List[Finding] = []

for file_path in root.rglob("*"):
if not file_path.is_file():
continue
for dirpath, dirnames, filenames in os.walk(root):
# Prune skip directories in-place to prevent descending into them.
dirnames[:] = [d for d in dirnames if d not in _SKIP_DIRS]

if _should_skip_path(file_path):
continue
for filename in filenames:
file_path = Path(dirpath) / filename

# Gitignore filtering.
if ignore_patterns:
try:
rel = str(file_path.relative_to(root))
except ValueError:
rel = str(file_path)
if _matches_gitignore(rel, ignore_patterns):
if _should_skip_path(file_path):
continue

findings.extend(scan_file(file_path))
# Gitignore filtering.
if ignore_patterns:
try:
rel = str(file_path.relative_to(root))
except ValueError:
rel = str(file_path)
if _matches_gitignore(rel, ignore_patterns):
continue

findings.extend(scan_file(file_path))

return findings


def scan_content(content: str, context: str = "content") -> List[Finding]:
def scan_content(
content: str,
context: str = "content",
config_threshold: Optional[float] = None,
) -> List[Finding]:
"""Quick scan of arbitrary content (convenience wrapper for hooks).

No file I/O — purely in-memory.

Args:
content: The text to scan.
context: Label used as the ``file`` field in findings.
config_threshold: Entropy threshold override for patterns without a threshold.

Returns:
List of Finding objects.
"""
return scan_text(content, filename=context)
return scan_text(content, filename=context, config_threshold=config_threshold)


# ---------------------------------------------------------------------------
Expand All @@ -284,7 +298,9 @@ def _scan_staged(root: Path) -> List[Finding]:
rel_name = rel_name.strip()
if not rel_name:
continue
file_path = root / rel_name
file_path = (root / rel_name).resolve()
if not file_path.is_relative_to(root):
continue
if _should_skip_path(file_path):
continue
findings.extend(scan_file(file_path))
Expand Down
Loading
Loading