Reconstructed unified diff: the reviewed copy had its line breaks collapsed.
Blank context lines and indentation are inferred from the hunk line counts and
normal Python layout; verify against the working tree before applying.

Review changes relative to the reviewed patch:
  * core.py: `import sys` is inserted in alphabetical position.
  * core.py (_backup_direct): the result of `_collect_files()` is materialized
    with list() so that len() cannot fail if it returns an iterator.
  * core.py (_chunk_archive): the sorted() traversal order is kept and only the
    top-level manifest.json is skipped, as the replaced loop did.
  * core.py / serve.py: the files keep a trailing newline at end of file (the
    core.py hunk that only stripped it is dropped).
  * serve.py (generate_restore_url): the truncated "Returns" text is completed.
  * The post-image blob ids on the `index` lines predate these changes.

diff --git a/pyproject.toml b/pyproject.toml
index c64cdb8..96e85bf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-clawhub = "clawhub>=0.1.0"
+clawhub = ["clawhub>=0.1.0"]
 
 [project.scripts]
 ocbs = "ocbs.cli:main"
diff --git a/src/ocbs/core.py b/src/ocbs/core.py
index bdc36cc..865a901 100644
--- a/src/ocbs/core.py
+++ b/src/ocbs/core.py
@@ -10,6 +10,7 @@
 import subprocess
+import sys
 import tarfile
 import tempfile
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
@@ -221,10 +222,8 @@ def _resolve_restore_path(self, file_path: str, target_dir: Path) -> Path:
 
         base_dir = target_dir.resolve()
         full_path = (base_dir / rel_path).resolve()
-        try:
-            full_path.relative_to(base_dir)
-        except ValueError as exc:
-            raise ValueError(f"restore path escapes target directory: {file_path}") from exc
+        if not full_path.is_relative_to(base_dir):
+            raise ValueError(f"restore path escapes target directory: {file_path}")
 
         return full_path
 
@@ -330,6 +329,7 @@ def _record_backup(
         scope: BackupScope,
         reason: str,
         files: Iterable[tuple[str, bytes]],
+        total_files: Optional[int] = None,
     ) -> BackupManifest:
         """Store backup file content and metadata."""
 
@@ -376,19 +376,42 @@ def _record_backup(
                     (backup_id, relative_path, chunk.chunk_id),
                 )
 
+            if total_files and sys.stdout.isatty():
+                print(f"\rBacking up... {len(manifest.paths)}/{total_files} files processed", end="", flush=True)
+            elif not sys.stdout.isatty():
+                # Non-TTY: log progress at dynamic intervals
+                interval = max(1, total_files // 100) if total_files else 100
+                if len(manifest.paths) % interval == 0:
+                    if total_files:
+                        print(f"Processed {len(manifest.paths)} / {total_files} files")
+                    else:
+                        print(f"Processed {len(manifest.paths)} files")
+
+        if sys.stdout.isatty() and total_files:
+            print()
+        elif not sys.stdout.isatty():
+            # Non-TTY: always emit final completion log
+            if total_files:
+                print(f"Completed: processed {len(manifest.paths)} / {total_files} files")
+            else:
+                print(f"Completed: processed {len(manifest.paths)} files")
+
         return manifest
 
     def _backup_direct(self, scope: BackupScope, reason: str = "") -> BackupManifest:
         """Back up files directly from the OpenClaw home."""
 
+        paths = self._get_paths_for_scope(scope)
+        all_files = list(self._collect_files(paths))
+        total_files = len(all_files)
+
         def _file_gen():
-            paths = self._get_paths_for_scope(scope)
-            for file_path in self._collect_files(paths):
+            for file_path in all_files:
                 rel_path = str(file_path.relative_to(Path.home()))
                 yield (rel_path, file_path.read_bytes())
 
         backup_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
-        return self._record_backup(backup_id, scope, reason, _file_gen())
+        return self._record_backup(backup_id, scope, reason, _file_gen(), total_files=total_files)
 
     def _run_native_backup(self, scope: BackupScope, dry_run: bool = False) -> Path:
         """Run OpenClaw native backup and return the archive path."""
@@ -475,16 +498,19 @@ def _chunk_archive(self, archive_path: Path, scope: BackupScope, reason: str = "
             with tarfile.open(archive_path, "r:gz") as tar:
                 self._safe_extract_archive(tar, extract_dir)
 
+            all_files = [
+                p
+                for p in sorted(extract_dir.rglob("*"))
+                if p.is_file() and p.relative_to(extract_dir) != Path("manifest.json")
+            ]
+            total_files = len(all_files)
+
             def _file_gen():
-                for file_path in sorted(extract_dir.rglob("*")):
-                    if not file_path.is_file():
-                        continue
+                for file_path in all_files:
                     rel_path = file_path.relative_to(extract_dir)
-                    if rel_path == Path("manifest.json"):
-                        continue
                     yield (str(rel_path), file_path.read_bytes())
 
-            return self._record_backup(backup_id, scope, reason, _file_gen())
+            return self._record_backup(backup_id, scope, reason, _file_gen(), total_files=total_files)
 
     def backup(
         self,
diff --git a/src/ocbs/serve.py b/src/ocbs/serve.py
index 7fec8fe..a3f03b4 100644
--- a/src/ocbs/serve.py
+++ b/src/ocbs/serve.py
@@ -6,6 +6,7 @@
 import socket
 import subprocess
 import urllib.parse
+from http.server import HTTPServer
 from pathlib import Path
 from typing import Optional
 from urllib.parse import urlencode
@@ -838,52 +839,38 @@ def start_restore_server(port: Optional[int] = 3456, bind_host: str = '127.0.0.1
     # Test detection
     conn_type, host = detect_connection_type()
     print(f"Detected connection: {conn_type} ({host})")
-    
-    def stop(self):
-        """Stop the HTTP server."""
-        global _global_server
-        if self.server:
-            self.server.shutdown()
-            self.server.server_close()
-        # Clear global reference if this is the global server
-        if _global_server is self:
-            _global_server = None
 
 
-# Global server instance for convenience functions
-_global_server: Optional[RestorePageServer] = None
+def generate_restore_url(checkpoint_id: str, port: int = 3456, host: str = "localhost") -> str:
+    """Generate a restore URL for a checkpoint.
 
+    Args:
+        checkpoint_id: The checkpoint ID (or token) to restore
+        port: Server port (default: 3456, matching start_restore_server default)
+        host: Server host (default: localhost)
 
-def start_restore_server(port: int = 18790, host: str = "localhost",
-                         bind_host: str = "127.0.0.1", state_dir: Optional[Path] = None):
-    """Start the restore server in the background."""
-    global _global_server
-    if _global_server is None:
-        _global_server = RestorePageServer(state_dir=state_dir, port=port, host=host, bind_host=bind_host)
-        _global_server.start(background=True)
-    return _global_server
+    Returns:
+        The restore endpoint URL: http://<host>:<port>/restore/<checkpoint_id>
+    """
+    return f"http://{host}:{port}/restore/{checkpoint_id}"
 
 
 def format_restore_message(checkpoint_id: str, reason: str,
-                           port: int = 18790, host: str = "localhost") -> str:
-    """Format a restore message with URL for a checkpoint."""
-    global _global_server
-    
-    # Ensure server is running
-    if _global_server is None:
-        start_restore_server(port=port, host=host)
+                           port: int = 3456, host: str = "localhost") -> str:
+    """Format a restore message with URL for a checkpoint.
 
-    # Create serve record for this checkpoint
-    token = _global_server.serve_checkpoint(checkpoint_id)
-    url = _global_server.get_restore_url(token)
+    Args:
+        checkpoint_id: The checkpoint ID (or token)
+        reason: Reason for the checkpoint
+        port: Server port (default: 3456, matching start_restore_server default)
+        host: Server host (default: localhost)
 
-    message = f"""
-Checkpoint created: {checkpoint_id}
+    Returns:
+        Formatted message including the restore URL
+    """
+    url = generate_restore_url(checkpoint_id, port, host)
+    return f"""Checkpoint created: {checkpoint_id}
 Reason: {reason}
 
-Restore URL (expires in 4 hours):
-{url}
-
-Share this URL to allow emergency restore of this checkpoint.
-"""
-    return message.strip()
\ No newline at end of file
+Restore URL: {url}
+"""
diff --git a/src/ocbs/skill.py b/src/ocbs/skill.py
index a0b5e47..34c06ef 100644
--- a/src/ocbs/skill.py
+++ b/src/ocbs/skill.py
@@ -28,9 +28,6 @@
 from .core import BackupSource, BackupScope, OCBSCore
 from .serve import generate_restore_url, format_restore_message, start_restore_server
 
-from .core import OCBSCore, BackupScope
-from .serve import RestorePageServer
-
 
 class OCBSBackupSkill:
     """Skill that exposes OCBS commands via chat."""
diff --git a/tests/test_core.py b/tests/test_core.py
index 6911517..430e891 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -505,10 +505,62 @@ def test_cleanup(self, ocbs, sample_files, temp_state_dir):
 
             # Run cleanup
             ocbs.cleanup(BackupScope.CONFIG)
-            
+
             # Should still have backups
             backups = ocbs.list_backups(BackupScope.CONFIG)
             assert len(backups) > 0
         finally:
             if original_home:
                 os.environ['HOME'] = original_home
+
+
+class TestResolveRestorePath:
+    """Tests for _resolve_restore_path path validation."""
+
+    def test_normal_relative_path(self, ocbs, temp_state_dir):
+        """Test that normal relative paths are resolved correctly."""
+        target = temp_state_dir / "restore"
+        result = ocbs._resolve_restore_path("config/settings.json", target)
+        assert result.resolve() == (target / "config" / "settings.json").resolve()
+
+    def test_openclaw_stripped(self, ocbs, temp_state_dir):
+        """Test that .openclaw prefix is stripped."""
+        target = temp_state_dir / "restore"
+        result = ocbs._resolve_restore_path(".openclaw/config/settings.json", target)
+        assert result.resolve() == (target / "config" / "settings.json").resolve()
+
+    def test_absolute_path_rejected(self, ocbs, temp_state_dir):
+        """Test that absolute paths are rejected."""
+        target = temp_state_dir / "restore"
+        with pytest.raises(ValueError, match="absolute restore paths are not allowed"):
+            ocbs._resolve_restore_path("/etc/passwd", target)
+
+    def test_empty_path_rejected(self, ocbs, temp_state_dir):
+        """Test that empty paths are rejected."""
+        target = temp_state_dir / "restore"
+        with pytest.raises(ValueError, match="empty restore path is not allowed"):
+            ocbs._resolve_restore_path("", target)
+
+    def test_path_traversal_rejected(self, ocbs, temp_state_dir):
+        """Test that path traversal attempts are rejected."""
+        target = temp_state_dir / "restore"
+        with pytest.raises(ValueError, match="restore path escapes target directory"):
+            ocbs._resolve_restore_path("../etc/passwd", target)
+
+    def test_deep_path_traversal_rejected(self, ocbs, temp_state_dir):
+        """Test that deep path traversal attempts are rejected."""
+        target = temp_state_dir / "restore"
+        with pytest.raises(ValueError, match="restore path escapes target directory"):
+            ocbs._resolve_restore_path("subdir/../../../../etc/passwd", target)
+
+    def test_symlink_escape_via_realpath(self, ocbs, temp_state_dir):
+        """Test that symlinks that escape target via realpath are rejected."""
+        target = temp_state_dir / "restore"
+        # Create a symlink inside target that points to parent
+        subdir = target / "subdir"
+        subdir.mkdir(parents=True)
+        symlink = subdir / "link"
+        symlink.symlink_to(temp_state_dir.parent)
+        # This path resolves via the symlink to escape the target
+        with pytest.raises(ValueError, match="restore path escapes target directory"):
+            ocbs._resolve_restore_path("subdir/link/../etc/passwd", target)