Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions snatch/cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Cache management for downloaded media information"""
import heapq
import threading
import logging
import json
Expand Down Expand Up @@ -56,14 +57,11 @@ def _cleanup_memory(self, force: bool = False) -> None:
self._memory_cache.pop(k, None)
self._access_times.pop(k, None)

# If still too many entries, remove oldest
# If still too many entries, evict oldest via heapq (O(n+k) vs O(n log n))
if len(self._memory_cache) > self.max_memory_entries:
sorted_items = sorted(
self._access_times.items(),
key=lambda x: x[1]
)
to_remove = len(self._memory_cache) - self.max_memory_entries
for k, _ in sorted_items[:to_remove]:
oldest = heapq.nsmallest(to_remove, self._access_times.items(), key=lambda x: x[1])
for k, _ in oldest:
self._memory_cache.pop(k, None)
self._access_times.pop(k, None)

Expand Down
37 changes: 27 additions & 10 deletions snatch/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""
Enhanced CLI module with Rich interface and preset support.

Rich Console and traceback are lazily initialized to speed up CLI startup.
"""

import asyncio
Expand All @@ -13,11 +15,7 @@
from typing import List, Optional, Dict, Any, NoReturn
import typer
import yaml
from rich.console import Console
from rich.traceback import install
from rich.prompt import Confirm
from rich.live import Live
from rich.table import Table

# Local imports
from .constants import VERSION, EXAMPLES, APP_NAME
Expand All @@ -34,11 +32,26 @@
from .customization_manager import CustomizationManager, ThemePreset, ConfigFormat, InterfaceMode, ProgressStyle, NotificationLevel
from .audio_processor import EnhancedAudioProcessor, AudioEnhancementSettings, AUDIO_ENHANCEMENT_PRESETS

# Enable Rich traceback formatting
install(show_locals=True)
# --- Lazy Rich Console (deferred from module level) ---
_console = None


def get_console():
"""Lazy Console factory — avoids creating Console until first use."""
global _console
if _console is None:
from rich.console import Console
_console = Console()
return _console


class _LazyConsole:
"""Proxy that defers Console creation until first attribute access."""
def __getattr__(self, name):
return getattr(get_console(), name)


# Initialize console
console = Console()
console = _LazyConsole()

# Constants for duplicate strings
FALLBACK_INTERACTIVE_MSG = "[yellow]Falling back to enhanced interactive mode.[/]"
Expand Down Expand Up @@ -75,10 +88,10 @@ class EnhancedCLI:
def __init__(self, config: Dict[str, Any]):
if not config:
raise ValueError("Configuration must be provided")

self.config = config
self._pending_download = None # Store pending download for async execution

# Initialize error handler
error_log_path = config.get("error_log_path", "logs/snatch_errors.log")
self.error_handler = EnhancedErrorHandler(log_file=error_log_path)
Expand Down Expand Up @@ -2204,6 +2217,10 @@ async def _p2p_library_command(self, action: str, library_name: str, directory:
def main():
"""Main entry point for the CLI application"""
try:
# Enable Rich traceback formatting (deferred from module level)
from rich.traceback import install
install(show_locals=True)

# Initialize configuration
config = asyncio.run(initialize_config_async())

Expand Down
19 changes: 16 additions & 3 deletions snatch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,19 +267,32 @@ def _ensure_config_directory() -> None:
if config_dir:
os.makedirs(config_dir, exist_ok=True)

_cached_config: Optional[Dict[str, Any]] = None
_config_mtime: float = 0.0


def _load_existing_config() -> Dict[str, Any]:
"""Load existing config file or return defaults"""
"""Load existing config file or return cached copy if unchanged on disk."""
global _cached_config, _config_mtime

config = DEFAULT_CONFIG.copy()

if os.path.exists(CONFIG_FILE):
try:
current_mtime = os.path.getmtime(CONFIG_FILE)
if _cached_config is not None and current_mtime == _config_mtime:
return _cached_config.copy()

with open(CONFIG_FILE) as f:
loaded_config = json.load(f)
if isinstance(loaded_config, dict):
config.update(loaded_config)

_cached_config = config
_config_mtime = current_mtime
except (json.JSONDecodeError, TypeError) as e:
logger.error(f"Failed to parse config file: {e}")

return config

def _ensure_output_directories(config: Dict[str, Any]) -> None:
Expand Down
6 changes: 4 additions & 2 deletions snatch/error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,10 @@ def _check_internet_connection(self, error_info: ErrorInfo) -> bool:
"""Check if internet connection is available"""
try:
import urllib.request
import urllib.error
urllib.request.urlopen('http://www.google.com', timeout=5)
return True
except:
except (OSError, urllib.error.URLError):
logging.error("No internet connection available")
return False

Expand Down Expand Up @@ -462,7 +463,8 @@ async def async_wrapper(*args, **kwargs):
raise
return None

return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper
import inspect
return async_wrapper if inspect.iscoroutinefunction(func) else wrapper
return decorator

@contextmanager
Expand Down
84 changes: 62 additions & 22 deletions snatch/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import re
import threading
import time
from io import BytesIO
import aiohttp
import aiofiles
import backoff
from abc import ABC, abstractmethod
from contextlib import contextmanager, asynccontextmanager
Expand Down Expand Up @@ -935,17 +937,31 @@ def _import_dependencies(self):
self.yt_dlp_available = False
logging.warning("yt-dlp not available, some functionality may be limited")

def _create_http_client(self) -> aiohttp.ClientSession:
"""Create an aiohttp session with connection pooling and timeouts."""
connector_kwargs = {
"limit": 30,
"limit_per_host": 10,
"ttl_dns_cache": 300,
}
# enable_cleanup_closed is deprecated in Python 3.14+ (CPython fix)
if sys.version_info < (3, 14):
connector_kwargs["enable_cleanup_closed"] = True
connector = aiohttp.TCPConnector(**connector_kwargs)
timeout = aiohttp.ClientTimeout(total=300, connect=30, sock_read=60)
return aiohttp.ClientSession(connector=connector, timeout=timeout)

@property
def http_client(self) -> HTTPClientProtocol:
"""Get the HTTP client session, creating it if needed"""
if not self._http_client:
self._http_client = aiohttp.ClientSession()
self._http_client = self._create_http_client()
return self._http_client

async def __aenter__(self):
"""Async context manager entry"""
if not self._http_client and not self.user_provided_client:
self._http_client = aiohttp.ClientSession()
self._http_client = self._create_http_client()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand All @@ -964,22 +980,29 @@ async def _calculate_sha256(self, data: bytes) -> str:
max_tries=5
)
async def _download_chunk(self, url: str, chunk: DownloadChunk) -> bool:
"""Download a single chunk with retries and exponential backoff"""
"""Download a single chunk with streaming reads and inline hashing"""
headers = {"Range": f"bytes={chunk.start}-{chunk.end}"}

try:
async with self.http_client.get(url, headers=headers) as response:
if response.status != 206:
logging.error(f"Range request failed: got status {response.status}")
return False

chunk.data = await response.read()
chunk.sha256 = await self._calculate_sha256(chunk.data)


# Stream data in 64KB sub-chunks, compute SHA256 inline
buffer = BytesIO()
hasher = hashlib.sha256()
async for data in response.content.iter_chunked(64 * 1024):
buffer.write(data)
hasher.update(data)

chunk.data = buffer.getvalue()
chunk.sha256 = hasher.hexdigest()

# Notify hooks
for hook in self.hooks:
await hook.post_chunk(chunk, chunk.sha256)

return True
except Exception as e:
logging.error(f"Chunk download error: {str(e)}")
Expand Down Expand Up @@ -1070,19 +1093,36 @@ async def download(self, url: str, output_path: str, **options) -> str:
progress.update(task, completed=resume_from)

try:
with open(temp_path, mode) as f:
for chunk in chunks:
success = await self._download_chunk(url, chunk)
if not success:
raise DownloadError(f"Failed to download chunk {chunk.start}-{chunk.end}")

f.write(chunk.data)
progress.update(task, advance=len(chunk.data))

# Update session
downloaded = chunk.end + 1
self.session_manager.update_session(url, {"progress": downloaded / total_size * 100})
# Rename temp file to final
max_concurrent = self.config.get("max_concurrent_chunks", 8)
semaphore = asyncio.Semaphore(max_concurrent)

async def _download_with_limit(c):
async with semaphore:
return await self._download_chunk(url, c)

async with aiofiles.open(temp_path, mode) as f:
# Download in parallel batches, write in order
for i in range(0, len(chunks), max_concurrent):
batch = chunks[i:i + max_concurrent]
results = await asyncio.gather(
*[_download_with_limit(c) for c in batch],
return_exceptions=True,
)
for c, result in zip(batch, results):
if isinstance(result, Exception):
raise DownloadError(f"Chunk {c.start}-{c.end} failed: {result}")
if not result:
raise DownloadError(f"Failed to download chunk {c.start}-{c.end}")

await f.write(c.data)
progress.update(task, advance=len(c.data))
c.data = None # Free memory eagerly

# Update session
downloaded = c.end + 1
self.session_manager.update_session(url, {"progress": downloaded / total_size * 100})

# Rename temp file to final
os.replace(temp_path, output_path)
except Exception as e:
logging.error(f"Download failed: {str(e)}")
Expand Down
20 changes: 9 additions & 11 deletions snatch/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,29 +196,27 @@ async def check_connection(self) -> bool:
return self.connection_status

async def _perform_connection_check(self) -> bool:
"""Perform actual connection check"""
# Check multiple reliable endpoints
"""Perform actual connection check — tests endpoints in parallel."""
test_endpoints = [
"https://www.google.com",
"https://www.cloudflare.com",
"https://www.microsoft.com",
"https://www.apple.com"
]

# Try to connect to each with a short timeout

timeout = aiohttp.ClientTimeout(total=3)
for endpoint in test_endpoints:

async def _check_one(endpoint: str) -> bool:
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.head(endpoint) as response:
if response.status < 400:
return True
return response.status < 400
except Exception as e:
# Just try the next endpoint
logger.debug(f"Connection check failed for {endpoint}: {e}")
continue

return False
return False

results = await asyncio.gather(*[_check_one(ep) for ep in test_endpoints])
return any(results)

async def get_connection_info(self) -> Dict[str, Any]:
"""Get detailed connection information"""
Expand Down
2 changes: 1 addition & 1 deletion snatch/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -2649,7 +2649,7 @@ async def _broadcast_discovery(self, network_addr: str, broadcast_addr: str) ->
finally:
try:
sock.close()
except:
except Exception:
pass

return discovered
Expand Down
2 changes: 1 addition & 1 deletion snatch/performance_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _collect_metrics(self) -> PerformanceMetrics:
"""Collect current system and application metrics"""
try:
# System metrics
cpu_percent = psutil.cpu_percent(interval=1)
cpu_percent = psutil.cpu_percent(interval=None)
memory = psutil.virtual_memory()
disk_io = psutil.disk_io_counters()
network_io = psutil.net_io_counters()
Expand Down
18 changes: 7 additions & 11 deletions snatch/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,6 @@ def _update_speed_metrics(self) -> None:
if len(self._speed_samples) > self._max_samples:
self._speed_samples.pop(0)

# Limit samples list size
if len(self._speed_samples) > self._max_samples:
self._speed_samples.pop(0)

# Update last sample info
self._last_sample_time = now
self._last_sample_bytes = self.downloaded
Expand Down Expand Up @@ -690,7 +686,7 @@ def update(self, n: int = 1) -> None:
try:
self.progress.update(n)
except Exception:
pass # Last resort if everything fails
logging.debug("Progress bar update failed", exc_info=True)

def set_description(self, description: str) -> None:
self.progress.set_description_str(description)
Expand Down Expand Up @@ -785,27 +781,27 @@ def __init__(self, message: str = "", style: str = "dots", color: str = "cyan"):
# Keep track of terminal width for dynamic resizing
try:
self.term_width = shutil.get_terminal_size().columns
except:
except Exception:
self.term_width = 80

def start(self):
"""Start the spinner animation in a separate thread"""
if self.running:
return

self.running = True
self._stop_event.clear()
self._pause_event.clear()

def spin():
index = 0
while not self._stop_event.is_set():
if not self._pause_event.is_set():
try:
# Get current terminal width for proper wrapping
self.term_width = shutil.get_terminal_size().columns
except:
pass
except Exception:
logging.debug("Terminal size detection failed")

status = f"\r{self.color}{self.frames[index]}{Style.RESET_ALL} {self.message}"

Expand Down
Loading
Loading