From 275b3afd6b5c96b84dda542afb32d0b6d2e805ed Mon Sep 17 00:00:00 2001 From: Aryan Kumar <156166681+aryan5v@users.noreply.github.com> Date: Wed, 29 Apr 2026 08:58:26 -0700 Subject: [PATCH] feat: add image file and dataset attachments --- README.md | 34 +++ agent/core/agent_loop.py | 42 ++- agent/core/attachments.py | 324 +++++++++++++++++++++ agent/main.py | 194 +++++++++++- agent/utils/terminal_display.py | 2 + backend/routes/agent.py | 96 +++++- backend/session_manager.py | 9 +- frontend/src/components/Chat/ChatInput.tsx | 159 +++++++++- frontend/src/components/SessionChat.tsx | 7 +- frontend/src/lib/sse-chat-transport.ts | 5 +- frontend/src/utils/api.ts | 5 +- tests/unit/test_attachments.py | 90 ++++++ 12 files changed, 932 insertions(+), 35 deletions(-) create mode 100644 agent/core/attachments.py create mode 100644 tests/unit/test_attachments.py diff --git a/README.md b/README.md index 8a6c1ccd..f30d4e25 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,40 @@ ml-intern --max-iterations 100 "your prompt" ml-intern --no-stream "your prompt" ``` +**Attach local files and datasets:** + +Use `--file` or `--image` when a local file should be visible to the next +agent turn only: + +```bash +ml-intern "summarize this CSV" --file ./data.csv +ml-intern "what does this screenshot show?" --image ./screenshot.png +``` + +In interactive mode, queue local files for the next submitted message: + +```text +/attach ./data.csv ./notes.txt +/attach ./screenshot.png +``` + +Use `--dataset` or `/dataset` when the file should be imported to a private +Hugging Face dataset repo for training jobs or later reuse: + +```bash +ml-intern "fine-tune on this data" --dataset ./train.jsonl +``` + +```text +/dataset ./train.jsonl +``` + +Dataset imports are stored under a private repo named +`{username}/ml-intern-user-datasets`. Plain `--file`, `--image`, and `/attach` +do not upload to the Hub; they are per-turn context only. 
In the web UI, use +the paperclip button or drag and drop files into the composer, then choose +between **Attach to turn** and **Import as dataset** before sending. + ## Supported Gateways ML Intern currently supports one-way notification gateways from CLI sessions. diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index 8b7a4572..b7902f85 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -21,6 +21,7 @@ from agent.config import Config from agent.messaging.gateway import NotificationGateway from agent.core import telemetry +from agent.core.attachments import build_user_content from agent.core.doom_loop import check_for_doom_loop from agent.core.llm_params import _resolve_llm_params from agent.core.prompt_caching import with_prompt_caching @@ -826,7 +827,7 @@ async def _abandon_pending_approval(session: Session) -> None: @staticmethod async def run_agent( - session: Session, text: str, + session: Session, text: str, attachments: list[dict[str, Any]] | None = None, ) -> str | None: """ Handle user input (like user_input_or_turn in codex.rs:1291) @@ -840,10 +841,35 @@ async def run_agent( if text and session.pending_approval: await Handlers._abandon_pending_approval(session) - # Add user message to history only if there's actual content - if text: - user_msg = Message(role="user", content=text) - session.context_manager.add_message(user_msg) + redacted_user_msg: Message | None = None + raw_user_msg: Message | None = None + + def _redact_transient_user_content() -> None: + if raw_user_msg is not None and redacted_user_msg is not None: + raw_user_msg.content = redacted_user_msg.content + + # Add user message to history only if there's actual content. Image + # bytes are sent transiently for the current LLM turn, then replaced + # with text-only placeholders before the history is persisted/reused. 
+ if text or attachments: + content = build_user_content(text, attachments or []) + redacted_text = build_user_content(text, [ + m for m in (attachments or []) if m.get("type") == "dataset_import" + ]) + if attachments: + # Preserve manifest/preview context in the persisted text while + # avoiding raw image data URLs. + from agent.core.attachments import attachment_note + + redacted_text = (text or "").rstrip() + attachment_note(attachments or []) + raw_user_msg = Message(role="user", content=content) + redacted_user_msg = Message(role="user", content=redacted_text) + if content != redacted_text: + session.context_manager.items.append(raw_user_msg) + if session.context_manager.on_message_added: + session.context_manager.on_message_added(redacted_user_msg) + else: + session.context_manager.add_message(raw_user_msg) # Send event that we're processing await session.send_event( @@ -1196,6 +1222,7 @@ async def _exec_tool( } # Return early - wait for EXEC_APPROVAL operation + _redact_transient_user_content() return None iteration += 1 @@ -1228,6 +1255,8 @@ async def _exec_tool( errored = True break + _redact_transient_user_content() + if session.is_cancelled: await _cleanup_on_cancel(session) await session.send_event(Event(event_type="interrupted")) @@ -1523,7 +1552,8 @@ async def process_submission(session: Session, submission) -> bool: if op.op_type == OpType.USER_INPUT: text = op.data.get("text", "") if op.data else "" - await Handlers.run_agent(session, text) + attachments = op.data.get("attachments", []) if op.data else [] + await Handlers.run_agent(session, text, attachments) return True if op.op_type == OpType.COMPACT: diff --git a/agent/core/attachments.py b/agent/core/attachments.py new file mode 100644 index 00000000..c8b1f1de --- /dev/null +++ b/agent/core/attachments.py @@ -0,0 +1,324 @@ +"""User-selected file and dataset attachment helpers. + +The agent never gets ambient access to a user's laptop. 
DEFAULT_MAX_ATTACHMENT_BYTES = 50 * 1024 * 1024
DEFAULT_MAX_UPLOAD_BYTES = 200 * 1024 * 1024
TEXT_PREVIEW_BYTES = 16 * 1024
STAGING_ROOT = Path(os.environ.get("ML_INTERN_UPLOAD_DIR", "/tmp/ml-intern-uploads"))

# Extensions previewed as text even when ``mimetypes`` does not report text/*.
_TEXT_SUFFIXES = frozenset(
    {".csv", ".json", ".jsonl", ".md", ".py", ".txt", ".tsv", ".yaml", ".yml"}
)


class AttachmentError(ValueError):
    """Readable validation/import error safe to surface to users."""


@dataclass(frozen=True)
class AttachmentSource:
    """A deliberate local file selected by CLI or web upload staging."""

    path: Path
    original_name: str | None = None
    kind: str = "file"


def sanitize_filename(name: str) -> str:
    """Return a conservative filename safe for HF repo paths.

    Only the final path component is kept; runs of characters outside
    ``A-Za-z0-9._ -`` become ``_``, whitespace becomes ``_``, leading/trailing
    ``._- `` are stripped and the result is capped at 180 characters.  An
    empty result falls back to ``"attachment"``.
    """
    base = Path(name).name.strip().replace("\x00", "")
    base = re.sub(r"[^A-Za-z0-9._ -]+", "_", base)
    base = re.sub(r"\s+", "_", base).strip("._- ")
    return base[:180] or "attachment"


def _safe_component(value: str, label: str) -> str:
    """Return *value* if it is a single, non-special path component.

    Raises:
        AttachmentError: if *value* is empty, ``.``/``..``, or contains a path
            separator or NUL.  Staging paths are built from caller-supplied
            identifiers, so anything else could escape ``STAGING_ROOT``.
    """
    text = str(value)
    if not text or text in {".", ".."} or any(ch in text for ch in ("/", "\\", "\x00")):
        raise AttachmentError(f"Invalid {label}: {value!r}")
    return text


def _staging_dir(scope_id: str, upload_id: str) -> Path:
    """Return ``STAGING_ROOT/<scope_id>/<upload_id>`` after validating both ids."""
    return (
        STAGING_ROOT
        / _safe_component(scope_id, "scope id")
        / _safe_component(upload_id, "upload id")
    )


def _kind_for(path: Path, declared_kind: str = "file") -> tuple[str, str]:
    """Return ``(kind, mime_type)`` for *path*.

    The MIME type is guessed from the file name (defaulting to
    ``application/octet-stream``).  The kind is ``"image"`` when the caller
    declared it or the guessed MIME type is ``image/*``; otherwise ``"file"``.
    """
    mime, _ = mimetypes.guess_type(path.name)
    mime = mime or "application/octet-stream"
    if declared_kind == "image" or mime.startswith("image/"):
        return "image", mime
    return "file", mime


def _validate_path(path: Path, *, max_bytes: int) -> int:
    """Return the size of *path* in bytes after validating it.

    Raises:
        AttachmentError: if *path* is missing, not a regular file, empty, or
            larger than *max_bytes*.
    """
    if not path.exists():
        raise AttachmentError(f"Attachment not found: {path}")
    if not path.is_file():
        raise AttachmentError(f"Attachment is not a file: {path}")
    size = path.stat().st_size
    if size <= 0:
        raise AttachmentError(f"Attachment is empty: {path}")
    if size > max_bytes:
        mb = max_bytes // (1024 * 1024)
        raise AttachmentError(f"Attachment is too large: {path} ({size} bytes, max {mb} MB)")
    return size


def _text_preview(path: Path, mime_type: str) -> str | None:
    """Return up to ``TEXT_PREVIEW_BYTES`` of *path* decoded as UTF-8.

    Returns ``None`` when the file is not text-like (neither a ``text/*`` MIME
    type nor a known text extension) or yields no preview text.  Undecodable
    bytes are replaced with U+FFFD.
    """
    import codecs

    if not (mime_type.startswith("text/") or path.suffix.lower() in _TEXT_SUFFIXES):
        return None
    # Read only the preview window; attachments may be tens of megabytes.
    with path.open("rb") as handle:
        raw = handle.read(TEXT_PREVIEW_BYTES)
    if not raw:
        return None
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    # A full window may end in the middle of a multi-byte character;
    # ``final=False`` drops that dangling fragment instead of emitting U+FFFD.
    text = decoder.decode(raw, final=len(raw) < TEXT_PREVIEW_BYTES)
    return text or None


def _file_item(path: Path, *, original_name: str | None, declared_kind: str) -> dict[str, Any]:
    """Validate *path* and describe it as a manifest item.

    The item carries ``kind``, ``filename`` (sanitized), ``original_name``,
    ``size_bytes`` and ``mime_type``; text-like files additionally get
    ``text_preview`` and ``preview_truncated``.

    Raises:
        AttachmentError: if the file fails validation against
            ``DEFAULT_MAX_ATTACHMENT_BYTES``.
    """
    size = _validate_path(path, max_bytes=DEFAULT_MAX_ATTACHMENT_BYTES)
    filename = sanitize_filename(original_name or path.name)
    kind, mime_type = _kind_for(path, declared_kind)
    item: dict[str, Any] = {
        "kind": kind,
        "filename": filename,
        "original_name": original_name or path.name,
        "size_bytes": size,
        "mime_type": mime_type,
    }
    preview = _text_preview(path, mime_type)
    if preview:
        item["text_preview"] = preview
        item["preview_truncated"] = size > TEXT_PREVIEW_BYTES
    return item


def create_context_manifest(
    sources: Iterable[AttachmentSource],
    *,
    scope_id: str,
    upload_id: str | None = None,
    copy_to_staging: bool = False,
) -> dict[str, Any]:
    """Create a local-only attachment manifest for one agent turn.

    Args:
        sources: Files to attach, in display order.
        scope_id: Owner of the upload (session id or a CLI scope).
        upload_id: Identifier of this batch; a random hex id when omitted.
        copy_to_staging: When true, copy every file (and a preview-free
            ``manifest.json``) under ``STAGING_ROOT/<scope_id>/<upload_id>``
            and point each item's ``path`` at the staged copy.

    Returns:
        A ``context_upload`` manifest dict with one item per source.

    Raises:
        AttachmentError: if no sources were given, a file fails validation,
            or (when staging) an id is not a safe path component.
    """
    upload_id = upload_id or uuid.uuid4().hex
    source_list = list(sources)
    if not source_list:
        raise AttachmentError("No attachments were provided.")

    # Validate every file before touching the staging area so a bad batch
    # never leaves a half-populated directory behind.
    items: list[dict[str, Any]] = []
    for idx, source in enumerate(source_list, start=1):
        path = Path(source.path).expanduser()
        item = _file_item(path, original_name=source.original_name, declared_kind=source.kind)
        label = "Image" if item["kind"] == "image" else "File"
        item["path"] = str(path)
        item["placeholder"] = f"[{label} #{idx}]"
        items.append(item)

    manifest: dict[str, Any] = {
        "type": "context_upload",
        "upload_id": upload_id,
        "scope_id": scope_id,
        "items": items,
    }
    if copy_to_staging:
        prefix = _staging_dir(scope_id, upload_id)
        prefix.mkdir(parents=True, exist_ok=True)
        for idx, item in enumerate(items, start=1):
            staged_path = prefix / f"{idx:03d}-{item['filename']}"
            shutil.copyfile(item["path"], staged_path)
            item["path"] = str(staged_path)
        manifest_path = prefix / "manifest.json"
        manifest_path.write_text(
            json.dumps(_manifest_without_previews(manifest), indent=2),
            encoding="utf-8",
        )
        manifest["manifest_path"] = str(manifest_path)
    return manifest


def load_context_manifest(scope_id: str, upload_id: str) -> dict[str, Any]:
    """Load a staged manifest and rehydrate its text previews.

    Raises:
        AttachmentError: if an id is not a safe path component, the manifest
            does not exist, or it cannot be parsed as a JSON object.
    """
    manifest_path = _staging_dir(scope_id, upload_id) / "manifest.json"
    if not manifest_path.is_file():
        raise AttachmentError(f"Upload not found: {upload_id}")
    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        raise AttachmentError(f"Upload manifest is unreadable: {upload_id}") from exc
    if not isinstance(manifest, dict):
        raise AttachmentError(f"Upload manifest is unreadable: {upload_id}")
    # Rehydrate previews from staged files; they are intentionally not persisted
    # to the JSON manifest to keep local history compact.
    for item in manifest.get("items", []):
        raw_path = item.get("path")
        if not raw_path:
            continue
        path = Path(raw_path)
        if not path.is_file():
            continue
        preview = _text_preview(path, item.get("mime_type") or "")
        if preview:
            item["text_preview"] = preview
            item["preview_truncated"] = path.stat().st_size > TEXT_PREVIEW_BYTES
    return manifest


def _manifest_without_previews(manifest: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *manifest* whose items omit ``text_preview``."""
    clean = dict(manifest)
    clean["items"] = [
        {key: value for key, value in item.items() if key != "text_preview"}
        for item in manifest.get("items", [])
    ]
    return clean


def default_dataset_repo_id(username: str) -> str:
    """Return the default private dataset repo id for *username*."""
    return f"{username}/ml-intern-user-datasets"


def _repo_username(api: Any, token: str | None) -> str:
    """Resolve the Hub username via ``api.whoami``.

    Raises:
        AttachmentError: if the response has neither ``name`` nor ``fullname``.
    """
    whoami = api.whoami(token=token) if token else api.whoami()
    username = whoami.get("name") or whoami.get("fullname")
    if not username:
        raise AttachmentError("Could not resolve Hugging Face username for dataset import.")
    return username


def import_dataset_batch(
    sources: Iterable[AttachmentSource],
    *,
    token: str | None,
    scope_id: str,
    upload_id: str | None = None,
    repo_id: str | None = None,
    api: Any | None = None,
) -> dict[str, Any]:
    """Upload files plus manifest.json to a private HF dataset repo.

    Args:
        sources: Local files to import.
        token: Hugging Face token forwarded to every Hub call.
        scope_id: Owner of the import; sanitized into the repo path.
        upload_id: Identifier of this batch; a random hex id when omitted.
        repo_id: Target dataset repo; defaults to the caller's
            ``<username>/ml-intern-user-datasets``.
        api: Optional pre-built Hub client.  ``huggingface_hub`` is only
            imported when this is omitted.

    Returns:
        A ``dataset_import`` manifest whose items carry ``path_in_repo``
        instead of local paths or previews.

    Raises:
        AttachmentError: if ``huggingface_hub`` is needed but missing, a file
            fails validation, or the username cannot be resolved.
    """
    import tempfile

    if api is None:
        try:
            from huggingface_hub import HfApi
        except ImportError as exc:  # pragma: no cover - import guard
            raise AttachmentError("huggingface_hub is required for dataset imports.") from exc
        api = HfApi(token=token)

    upload_id = upload_id or uuid.uuid4().hex
    # Validate the local files first so bad input fails before any Hub call.
    staged = create_context_manifest(
        sources,
        scope_id=scope_id,
        upload_id=upload_id,
        copy_to_staging=False,
    )
    if repo_id is None:
        repo_id = default_dataset_repo_id(_repo_username(api, token))
    path_prefix = f"sessions/{sanitize_filename(scope_id)}/{upload_id}"

    api.create_repo(repo_id=repo_id, repo_type="dataset", private=True, exist_ok=True, token=token)
    items: list[dict[str, Any]] = []
    for idx, item in enumerate(staged["items"], start=1):
        filename = f"{idx:03d}-{item['filename']}"
        path_in_repo = f"{path_prefix}/files/{filename}"
        api.upload_file(
            path_or_fileobj=item["path"],
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="dataset",
            token=token,
            commit_message=f"Add ML Intern dataset file {filename}",
        )
        # Local paths and previews stay out of the durable manifest.
        clean_item = {k: v for k, v in item.items() if k not in {"path", "text_preview"}}
        clean_item["path_in_repo"] = path_in_repo
        items.append(clean_item)

    manifest: dict[str, Any] = {
        "type": "dataset_import",
        "upload_id": upload_id,
        "scope_id": scope_id,
        "repo_id": repo_id,
        "repo_type": "dataset",
        "path_prefix": path_prefix,
        "manifest_path": f"{path_prefix}/manifest.json",
        "items": items,
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as tmp:
        json.dump(manifest, tmp, indent=2)
        tmp_path = tmp.name
    try:
        api.upload_file(
            path_or_fileobj=tmp_path,
            path_in_repo=manifest["manifest_path"],
            repo_id=repo_id,
            repo_type="dataset",
            token=token,
            commit_message="Add ML Intern dataset manifest",
        )
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            # Best-effort cleanup of the temporary manifest copy.
            pass
    return manifest


def attachment_note(manifests: Iterable[dict[str, Any]]) -> str:
    """Build a text note injected into the model turn.

    The note lists every manifest (dataset imports with their repo location,
    local batches with their upload id), each item's placeholder, name, MIME
    type and size, and any text preview, followed by a usage reminder.
    """
    lines = ["\n\n[Attached context]"]
    for manifest in manifests:
        if manifest.get("type") == "dataset_import":
            lines.append(
                "- Imported dataset batch: "
                f"repo={manifest.get('repo_id')} "
                f"path_prefix={manifest.get('path_prefix')} "
                f"manifest={manifest.get('manifest_path')}. "
                "Use this HF dataset path for training, jobs, and durable reuse."
            )
        else:
            lines.append(f"- Local per-turn attachment batch: upload_id={manifest.get('upload_id')}.")
        for item in manifest.get("items", []):
            lines.append(
                f" {item.get('placeholder', '')} {item.get('filename')} "
                f"({item.get('mime_type')}, {item.get('size_bytes')} bytes)"
            )
            preview = item.get("text_preview")
            if preview:
                suffix = "\n [preview truncated]" if item.get("preview_truncated") else ""
                lines.append(f" Preview:\n```text\n{preview}\n```{suffix}")
    lines.append(
        "Only use files explicitly listed above. Local per-turn files are not durable; "
        "ask the user to import them as a dataset if an HF Job or later turn needs full access."
    )
    return "\n".join(lines)


def build_user_content(text: str, manifests: Iterable[dict[str, Any]]) -> str | list[dict[str, Any]]:
    """Return LiteLLM-compatible user content with transient image parts.

    Returns the plain text (user text plus attachment note) when no image can
    be embedded; otherwise a list with one ``text`` part followed by one
    ``image_url`` data-URL part per readable image in a non-dataset manifest.
    """
    manifest_list = list(manifests)
    note = attachment_note(manifest_list) if manifest_list else ""
    text_part = (text or "").rstrip() + note
    parts: list[dict[str, Any]] = [{"type": "text", "text": text_part}]
    for manifest in manifest_list:
        if manifest.get("type") == "dataset_import":
            continue
        for item in manifest.get("items", []):
            if item.get("kind") != "image":
                continue
            raw_path = item.get("path")
            # ``Path("")`` is the current directory and "exists", so require a
            # real file before reading bytes.
            if not raw_path or not Path(raw_path).is_file():
                continue
            data = base64.b64encode(Path(raw_path).read_bytes()).decode("ascii")
            mime_type = item.get("mime_type") or "image/png"
            parts.append({"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{data}"}})
    return parts if len(parts) > 1 else text_part
= "file") -> list[AttachmentSource]: + return [ + AttachmentSource(path=Path(path).expanduser(), original_name=Path(path).name, kind=kind) + for path in paths + ] + + +def _placeholders_for(manifests: list[dict[str, Any]]) -> str: + labels: list[str] = [] + for manifest in manifests: + for item in manifest.get("items", []): + placeholder = item.get("placeholder") + filename = item.get("filename") + if placeholder and filename: + labels.append(f"{placeholder} {filename}") + return "\n".join(labels) + + +def _compose_display_text(text: str, manifests: list[dict[str, Any]]) -> str: + placeholders = _placeholders_for(manifests) + if not placeholders: + return text + return f"{text.rstrip()}\n\n{placeholders}" + + def _configure_runtime_logging() -> None: """Keep third-party warning spam from punching through the interactive UI.""" import logging @@ -711,6 +754,9 @@ async def _handle_slash_command( session_holder: list, submission_queue: asyncio.Queue, submission_id: list[int], + pending_attachments: list[dict[str, Any]], + pending_datasets: list[dict[str, Any]], + hf_token: str | None, ) -> Submission | None: """ Handle a slash command. 
Returns a Submission to enqueue, or None if @@ -741,6 +787,45 @@ async def _handle_slash_command( operation=Operation(op_type=OpType.COMPACT), ) + if command in {"/attach", "/dataset"}: + console = get_console() + if not arg: + console.print(f"[bold red]Usage:[/bold red] {command} PATH [PATH ...]") + return None + try: + paths = shlex.split(arg) + except ValueError as exc: + console.print(f"[bold red]Could not parse paths:[/bold red] {exc}") + return None + scope_id = getattr(session_holder[0], "session_id", None) or f"cli-{uuid.uuid4().hex[:8]}" + try: + if command == "/attach": + manifest = create_context_manifest( + _sources_from_paths(paths), + scope_id=scope_id, + copy_to_staging=False, + ) + pending_attachments.append(manifest) + console.print(f"[green]Attached for next turn:[/green] {', '.join(paths)}") + else: + if not hf_token: + console.print("[bold red]A Hugging Face token is required for /dataset.[/bold red]") + return None + manifest = await asyncio.to_thread( + import_dataset_batch, + _sources_from_paths(paths), + token=hf_token, + scope_id=scope_id, + ) + pending_datasets.append(manifest) + console.print( + "[green]Dataset imported for next turn:[/green] " + f"{manifest['repo_id']}:{manifest['path_prefix']}" + ) + except AttachmentError as exc: + console.print(f"[bold red]Attachment error:[/bold red] {exc}") + return None + if command == "/model": console = get_console() if not arg: @@ -811,7 +896,12 @@ async def _handle_slash_command( return None -async def main(model: str | None = None): +async def main( + model: str | None = None, + file_paths: list[str] | None = None, + image_paths: list[str] | None = None, + dataset_paths: list[str] | None = None, +): """Interactive chat with the agent""" # Clear screen @@ -889,6 +979,26 @@ async def main(model: str | None = None): await ready_event.wait() submission_id = [0] + pending_attachments: list[dict[str, Any]] = [] + pending_datasets: list[dict[str, Any]] = [] + scope_id = getattr(session_holder[0], 
"session_id", None) or f"cli-{uuid.uuid4().hex[:8]}" + try: + initial_sources = _sources_from_paths(file_paths or []) + _sources_from_paths(image_paths or [], kind="image") + if initial_sources: + pending_attachments.append( + create_context_manifest(initial_sources, scope_id=scope_id, copy_to_staging=False) + ) + if dataset_paths: + pending_datasets.append( + await asyncio.to_thread( + import_dataset_batch, + _sources_from_paths(dataset_paths), + token=hf_token, + scope_id=scope_id, + ) + ) + except AttachmentError as exc: + get_console().print(f"[bold red]Attachment error:[/bold red] {exc}") # Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137 # (`QUIT_SHORTCUT_TIMEOUT = Duration::from_secs(1)`). Two Ctrl+C presses # within this window quit; a single press cancels the in-flight turn. @@ -982,7 +1092,14 @@ def _install_sigint() -> bool: # Handle slash commands if user_input.strip().startswith("/"): sub = await _handle_slash_command( - user_input.strip(), config, session_holder, submission_queue, submission_id + user_input.strip(), + config, + session_holder, + submission_queue, + submission_id, + pending_attachments, + pending_datasets, + hf_token, ) if sub is None: # Command handled locally, loop back for input @@ -993,11 +1110,18 @@ def _install_sigint() -> bool: continue # Submit to agent + turn_attachments = [*pending_attachments, *pending_datasets] + pending_attachments.clear() + pending_datasets.clear() submission_id[0] += 1 submission = Submission( id=f"sub_{submission_id[0]}", operation=Operation( - op_type=OpType.USER_INPUT, data={"text": user_input} + op_type=OpType.USER_INPUT, + data={ + "text": _compose_display_text(user_input, turn_attachments), + "attachments": turn_attachments, + }, ), ) await submission_queue.put(submission) @@ -1039,6 +1163,9 @@ async def headless_main( model: str | None = None, max_iterations: int | None = None, stream: bool = True, + file_paths: list[str] | None = None, + image_paths: list[str] | None = None, + dataset_paths: 
list[str] | None = None, ) -> None: """Run a single prompt headlessly and exit.""" import logging @@ -1100,9 +1227,36 @@ async def headless_main( break # Submit the prompt + scope_id = getattr(session_holder[0], "session_id", None) or f"cli-{uuid.uuid4().hex[:8]}" + attachments: list[dict[str, Any]] = [] + try: + initial_sources = _sources_from_paths(file_paths or []) + _sources_from_paths(image_paths or [], kind="image") + if initial_sources: + attachments.append( + create_context_manifest(initial_sources, scope_id=scope_id, copy_to_staging=False) + ) + if dataset_paths: + attachments.append( + await asyncio.to_thread( + import_dataset_batch, + _sources_from_paths(dataset_paths), + token=hf_token, + scope_id=scope_id, + ) + ) + except AttachmentError as exc: + print(f"ERROR: {exc}", file=sys.stderr) + sys.exit(2) + submission = Submission( id="sub_1", - operation=Operation(op_type=OpType.USER_INPUT, data={"text": prompt}), + operation=Operation( + op_type=OpType.USER_INPUT, + data={ + "text": _compose_display_text(prompt, attachments), + "attachments": attachments, + }, + ), ) await submission_queue.put(submission) @@ -1255,16 +1409,44 @@ def cli(): help="Max LLM requests per turn (default: 50, use -1 for unlimited)") parser.add_argument("--no-stream", action="store_true", help="Disable token streaming (use non-streaming LLM calls)") + parser.add_argument("--file", action="append", default=[], + help="Attach a local file for the next turn. Repeat or comma-separate paths.") + parser.add_argument("--image", action="append", default=[], + help="Attach a local image for the next turn. 
Repeat or comma-separate paths.") + parser.add_argument("--dataset", action="append", default=[], + help="Import local path(s) to your private HF dataset repo before the turn.") args = parser.parse_args() + file_paths = _split_path_args(args.file) + image_paths = _split_path_args(args.image) + dataset_paths = _split_path_args(args.dataset) try: if args.prompt: max_iter = args.max_iterations if max_iter is not None and max_iter < 0: max_iter = 10_000 # effectively unlimited - asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream)) + asyncio.run(headless_main( + args.prompt, + model=args.model, + max_iterations=max_iter, + stream=not args.no_stream, + file_paths=file_paths, + image_paths=image_paths, + dataset_paths=dataset_paths, + )) else: - asyncio.run(main(model=args.model)) + main_kwargs: dict[str, Any] = {"model": args.model} + if file_paths or image_paths or dataset_paths: + main_kwargs.update( + { + "file_paths": file_paths, + "image_paths": image_paths, + "dataset_paths": dataset_paths, + } + ) + asyncio.run( + main(**main_kwargs) + ) except KeyboardInterrupt: print("\n\nGoodbye!") diff --git a/agent/utils/terminal_display.py b/agent/utils/terminal_display.py index 8ff9d525..1671b680 100644 --- a/agent/utils/terminal_display.py +++ b/agent/utils/terminal_display.py @@ -421,6 +421,8 @@ def print_yolo_approve(count: int) -> None: {_I} [cyan]/help[/cyan] Show this help {_I} [cyan]/undo[/cyan] Undo last turn {_I} [cyan]/compact[/cyan] Compact context window +{_I} [cyan]/attach[/cyan] PATH... Attach local file(s) for the next turn +{_I} [cyan]/dataset[/cyan] PATH... 
def _resolve_chat_uploads(session_id: str, uploads: Any) -> list[dict[str, Any]]:
    """Turn the ``uploads`` field of a chat request into attachment manifests.

    Each entry may be an upload id string, a ``context_upload`` dict (resolved
    from staging by its ``upload_id``) or a ``dataset_import`` dict (used as
    sent).  A falsy ``uploads`` yields an empty list.

    Raises:
        HTTPException: 400 when ``uploads`` is not a list, an entry is
            malformed, a staged upload cannot be loaded, or a manifest's
            ``scope_id`` differs from ``session_id``.
    """
    if not uploads:
        return []
    if not isinstance(uploads, list):
        raise HTTPException(status_code=400, detail="'uploads' must be a list")
    manifests: list[dict[str, Any]] = []
    for upload in uploads:
        try:
            if isinstance(upload, str):
                manifest = load_context_manifest(session_id, upload)
            elif isinstance(upload, dict) and upload.get("type") == "context_upload":
                # Only the id is taken from the client; the manifest itself is
                # re-read from server-side staging.
                manifest = load_context_manifest(session_id, str(upload.get("upload_id") or ""))
            elif isinstance(upload, dict) and upload.get("type") == "dataset_import":
                # NOTE(review): the client-supplied dataset manifest is passed
                # through as-is; only ``scope_id`` is checked below.
                manifest = upload
            else:
                raise AttachmentError("Malformed upload reference.")
            if manifest.get("scope_id") != session_id:
                raise AttachmentError("Upload does not belong to this session.")
            manifests.append(manifest)
        except AttachmentError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
    return manifests


@router.post("/session/{session_id}/uploads")
async def upload_session_files(
    session_id: str,
    request: Request,
    files: list[UploadFile] = File(...),
    import_as_dataset: bool = Form(False),
    user: dict = Depends(get_current_user),
) -> dict:
    """Stage explicit user-selected files or import them to a private HF dataset.

    Uploaded files are streamed to a temporary directory, then either copied
    into per-session staging or uploaded to the Hub, depending on
    ``import_as_dataset``.

    Returns:
        ``{"status": "uploaded", "upload": <manifest>}``.

    Raises:
        HTTPException: 404 for a missing/inactive session, 400 for no files or
            an attachment error, 413 when a file exceeds the streaming cap.
    """
    agent_session = await _check_session_access(session_id, user, request)
    if not agent_session or not agent_session.is_active:
        raise HTTPException(status_code=404, detail="Session not found or inactive")
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory(prefix="ml-intern-upload-") as tmpdir:
        sources: list[AttachmentSource] = []
        for index, upload in enumerate(files, start=1):
            # Keep only the final path component of the client-supplied name.
            original_name = Path(upload.filename or "attachment").name
            target = Path(tmpdir) / f"{index:03d}-{original_name}"
            size = 0
            with target.open("wb") as out:
                # Stream in 1 MiB chunks so the size cap is enforced before the
                # whole body is written.
                while True:
                    chunk = await upload.read(1024 * 1024)
                    if not chunk:
                        break
                    size += len(chunk)
                    # NOTE(review): same value as DEFAULT_MAX_UPLOAD_BYTES in
                    # agent.core.attachments, but both manifest builders below
                    # validate against DEFAULT_MAX_ATTACHMENT_BYTES (50 MB), so
                    # files between the two limits are accepted here and then
                    # rejected with a 400.
                    if size > 200 * 1024 * 1024:
                        raise HTTPException(status_code=413, detail=f"{original_name} is too large")
                    out.write(chunk)
            kind = "image" if (upload.content_type or "").startswith("image/") else "file"
            sources.append(AttachmentSource(path=target, original_name=original_name, kind=kind))

        # Must run while the temporary directory still exists: both branches
        # read the files written above.
        try:
            if import_as_dataset:
                token = resolve_hf_request_token(request)
                if not token:
                    raise AttachmentError("A Hugging Face token is required to import datasets.")
                manifest = await asyncio.to_thread(
                    import_dataset_batch,
                    sources,
                    token=token,
                    scope_id=session_id,
                )
            else:
                # NOTE(review): runs synchronously in the request handler
                # (file copies included), unlike the dataset branch.
                manifest = create_context_manifest(
                    sources,
                    scope_id=session_id,
                    copy_to_staging=True,
                )
        except AttachmentError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    return {"status": "uploaded", "upload": manifest}
async def submit_user_input(
    self, session_id: str, text: str, attachments: list[dict[str, Any]] | None = None
) -> bool:
    """Submit user input to a session.

    The operation payload always carries ``text``; ``attachments`` is added
    only when a non-empty list is given.  Returns whatever ``self.submit``
    returns for the resulting ``USER_INPUT`` operation.
    """
    payload: dict[str, Any] = {"text": text}
    if attachments:
        payload["attachments"] = attachments
    return await self.submit(
        session_id,
        Operation(op_type=OpType.USER_INPUT, data=payload),
    )
MenuIte import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward'; import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown'; import StopIcon from '@mui/icons-material/Stop'; +import AttachFileIcon from '@mui/icons-material/AttachFile'; +import CloseIcon from '@mui/icons-material/Close'; +import StorageIcon from '@mui/icons-material/Storage'; import { apiFetch } from '@/utils/api'; import { useUserQuota } from '@/hooks/useUserQuota'; import ClaudeCapDialog from '@/components/ClaudeCapDialog'; @@ -64,21 +67,32 @@ const findModelByPath = (path: string): ModelOption | undefined => { interface ChatInputProps { sessionId?: string; - onSend: (text: string) => void; + onSend: (text: string, uploads?: unknown[]) => void; onStop?: () => void; isProcessing?: boolean; disabled?: boolean; placeholder?: string; } +interface PendingRetry { + inputText: string; + displayText: string; + uploads: unknown[]; +} + const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath); const firstFreeModel = () => MODEL_OPTIONS.find(m => !isClaudeModel(m)) ?? MODEL_OPTIONS[0]; export default function ChatInput({ sessionId, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' 
}: ChatInputProps) { const [input, setInput] = useState(''); const inputRef = useRef(null); + const fileInputRef = useRef(null); const [selectedModelId, setSelectedModelId] = useState(MODEL_OPTIONS[0].id); const [modelAnchorEl, setModelAnchorEl] = useState(null); + const [selectedFiles, setSelectedFiles] = useState>([]); + const [isUploading, setIsUploading] = useState(false); + const [uploadError, setUploadError] = useState(null); + const [importAsDataset, setImportAsDataset] = useState(false); const { quota, refresh: refreshQuota } = useUserQuota(); // The daily-cap dialog is triggered from two places: (a) a 429 returned // from the chat transport when the user tries to send on Opus over cap — @@ -90,7 +104,7 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa const jobsUpgradeRequired = useAgentStore((s) => s.jobsUpgradeRequired); const setJobsUpgradeRequired = useAgentStore((s) => s.setJobsUpgradeRequired); const [awaitingTopUp, setAwaitingTopUp] = useState(false); - const lastSentRef = useRef(''); + const lastSentRef = useRef(null); // Model is per-session: fetch this tab's current model every time the // session changes. Other tabs keep their own selections independently. 
@@ -119,19 +133,60 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa } }, [disabled, isProcessing]); - const handleSend = useCallback(() => { - if (input.trim() && !disabled) { - lastSentRef.current = input; - onSend(input); - setInput(''); + const addFiles = useCallback((files: FileList | File[]) => { + const next = Array.from(files).map((file) => ({ + id: `${file.name}-${file.lastModified}-${Math.random().toString(16).slice(2)}`, + file, + })); + setSelectedFiles((current) => [...current, ...next]); + setUploadError(null); + }, []); + + const uploadSelectedFiles = useCallback(async (): Promise => { + if (!sessionId || selectedFiles.length === 0) return []; + const form = new FormData(); + selectedFiles.forEach(({ file }) => form.append('files', file, file.name)); + form.append('import_as_dataset', String(importAsDataset)); + const res = await apiFetch(`/api/session/${sessionId}/uploads`, { + method: 'POST', + body: form, + }); + if (!res.ok) { + const detail = await res.text().catch(() => 'Upload failed'); + throw new Error(detail || 'Upload failed'); } - }, [input, disabled, onSend]); + const data = await res.json(); + return data?.upload ? [data.upload] : []; + }, [sessionId, selectedFiles, importAsDataset]); + + const handleSend = useCallback(async () => { + if ((input.trim() || selectedFiles.length > 0) && !disabled && !isUploading) { + setUploadError(null); + setIsUploading(true); + const baseText = input.trim(); + try { + const uploads = await uploadSelectedFiles(); + const placeholders = selectedFiles + .map(({ file }, index) => `[${file.type.startsWith('image/') ? 'Image' : 'File'} #${index + 1}] ${file.name}`) + .join('\n'); + const displayText = placeholders ? `${baseText}\n\n${placeholders}`.trim() : baseText; + lastSentRef.current = { inputText: baseText, displayText, uploads }; + onSend(displayText, uploads); + setInput(''); + setSelectedFiles([]); + } catch (err) { + setUploadError(err instanceof Error ? 
err.message : 'Upload failed'); + } finally { + setIsUploading(false); + } + } + }, [input, selectedFiles, disabled, isUploading, uploadSelectedFiles, onSend]); // When the chat transport reports a Claude-quota 429, restore the typed // text so the user doesn't lose their message. useEffect(() => { if (claudeQuotaExhausted && lastSentRef.current) { - setInput(lastSentRef.current); + setInput(lastSentRef.current.inputText); } }, [claudeQuotaExhausted]); @@ -191,11 +246,11 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa }); if (res.ok) { setSelectedModelId(free.id); - const retryText = lastSentRef.current; - if (retryText) { - onSend(retryText); + const retry = lastSentRef.current; + if (retry) { + onSend(retry.displayText, retry.uploads); setInput(''); - lastSentRef.current = ''; + lastSentRef.current = null; } } } catch { /* ignore */ } @@ -278,8 +333,9 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa className="composer" sx={{ display: 'flex', + flexDirection: 'column', gap: '10px', - alignItems: 'flex-start', + alignItems: 'stretch', bgcolor: 'var(--composer-bg)', borderRadius: 'var(--radius-md)', p: '12px', @@ -290,7 +346,79 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa boxShadow: 'var(--focus)', } }} + onDragOver={(e) => { + e.preventDefault(); + }} + onDrop={(e) => { + e.preventDefault(); + if (!disabled && !isProcessing && e.dataTransfer.files.length > 0) { + addFiles(e.dataTransfer.files); + } + }} > + { + if (e.target.files) addFiles(e.target.files); + e.target.value = ''; + }} + /> + {(selectedFiles.length > 0 || uploadError) && ( + + {selectedFiles.map(({ id, file }) => ( + setSelectedFiles((files) => files.filter((f) => f.id !== id))} + deleteIcon={} + sx={{ + maxWidth: '220px', + bgcolor: 'rgba(255,255,255,0.06)', + color: 'var(--text)', + border: '1px solid var(--divider)', + '& .MuiChip-label': { overflow: 'hidden', textOverflow: 'ellipsis' }, 
+ }} + /> + ))} + {selectedFiles.length > 0 && ( + } + label={importAsDataset ? 'Import as dataset' : 'Attach to turn'} + onClick={() => setImportAsDataset((v) => !v)} + sx={{ + bgcolor: importAsDataset ? 'var(--accent-yellow)' : 'transparent', + color: importAsDataset ? '#000' : 'var(--muted-text)', + border: '1px solid var(--divider)', + fontWeight: 600, + }} + /> + )} + {uploadError && ( + + {uploadError} + + )} + + )} + + fileInputRef.current?.click()} + disabled={disabled || isProcessing || isUploading} + sx={{ + mt: 1, + p: 1, + borderRadius: '10px', + color: 'var(--muted-text)', + '&:hover': { color: 'var(--accent-yellow)', bgcolor: 'var(--hover-bg)' }, + }} + > + + )} + {/* Powered By Badge */} diff --git a/frontend/src/components/SessionChat.tsx b/frontend/src/components/SessionChat.tsx index 8f182380..ba0afccd 100644 --- a/frontend/src/components/SessionChat.tsx +++ b/frontend/src/components/SessionChat.tsx @@ -68,11 +68,14 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess const busy = isProcessing || sdkBusy; const handleSendMessage = useCallback( - async (text: string) => { + async (text: string, uploads?: unknown[]) => { if (!text.trim() || busy) return; updateSession(sessionId, { isProcessing: true, activityStatus: { type: 'thinking' } }); - sendMessage({ text: text.trim(), metadata: { createdAt: new Date().toISOString() } }); + sendMessage({ + text: text.trim(), + metadata: { createdAt: new Date().toISOString(), uploads }, + }); // Auto-title the session from the first user message const isFirstMessage = messages.filter((m) => m.role === 'user').length === 0; diff --git a/frontend/src/lib/sse-chat-transport.ts b/frontend/src/lib/sse-chat-transport.ts index 77f85189..186a86b0 100644 --- a/frontend/src/lib/sse-chat-transport.ts +++ b/frontend/src/lib/sse-chat-transport.ts @@ -382,7 +382,10 @@ export class SSEChatTransport implements ChatTransport { .map(p => p.text) .join('') : ''; - body = { text }; + const uploads = 
lastUserMsg?.metadata && typeof lastUserMsg.metadata === 'object' + ? (lastUserMsg.metadata as Record).uploads + : undefined; + body = uploads ? { text, uploads } : { text }; } // POST to SSE endpoint diff --git a/frontend/src/utils/api.ts b/frontend/src/utils/api.ts index 4dc72074..71522205 100644 --- a/frontend/src/utils/api.ts +++ b/frontend/src/utils/api.ts @@ -12,8 +12,9 @@ export async function apiFetch( path: string, options: RequestInit = {} ): Promise { + const isFormData = options.body instanceof FormData; const headers: Record = { - 'Content-Type': 'application/json', + ...(isFormData ? {} : { 'Content-Type': 'application/json' }), ...(options.headers as Record), }; @@ -38,4 +39,4 @@ export async function apiFetch( } return response; -} \ No newline at end of file +} diff --git a/tests/unit/test_attachments.py b/tests/unit/test_attachments.py new file mode 100644 index 00000000..47c93db6 --- /dev/null +++ b/tests/unit/test_attachments.py @@ -0,0 +1,90 @@ +from pathlib import Path + +from agent.core.attachments import ( + AttachmentSource, + build_user_content, + create_context_manifest, + import_dataset_batch, + sanitize_filename, +) +from agent.main import _split_path_args + + +class FakeHfApi: + def __init__(self): + self.created = [] + self.uploads = [] + + def whoami(self, token=None): + return {"name": "alice"} + + def create_repo(self, **kwargs): + self.created.append(kwargs) + + def upload_file(self, **kwargs): + self.uploads.append(kwargs) + + +def test_sanitize_filename_keeps_safe_basename(): + assert sanitize_filename("../../my data?.csv") == "my_data_.csv" + assert sanitize_filename(" ") == "attachment" + + +def test_context_manifest_includes_metadata_and_text_preview(tmp_path: Path): + path = tmp_path / "rows.csv" + path.write_text("a,b\n1,2\n", encoding="utf-8") + + manifest = create_context_manifest( + [AttachmentSource(path=path)], + scope_id="session-1", + ) + + assert manifest["type"] == "context_upload" + item = manifest["items"][0] + 
assert item["filename"] == "rows.csv" + assert item["mime_type"] in {"text/csv", "application/vnd.ms-excel"} + assert item["text_preview"] == "a,b\n1,2\n" + assert item["placeholder"] == "[File #1]" + + +def test_build_user_content_adds_transient_image_part(tmp_path: Path): + image = tmp_path / "image.png" + image.write_bytes(b"\x89PNG\r\n\x1a\n") + manifest = create_context_manifest( + [AttachmentSource(path=image, kind="image")], + scope_id="session-1", + ) + + content = build_user_content("describe this", [manifest]) + + assert isinstance(content, list) + assert content[0]["type"] == "text" + assert content[1]["type"] == "image_url" + assert content[1]["image_url"]["url"].startswith("data:image/png;base64,") + + +def test_import_dataset_batch_uploads_files_and_manifest(tmp_path: Path): + data = tmp_path / "train.jsonl" + data.write_text('{"text":"hi"}\n', encoding="utf-8") + api = FakeHfApi() + + manifest = import_dataset_batch( + [AttachmentSource(path=data)], + token="hf_test", + scope_id="run-1", + upload_id="upload-1", + api=api, + ) + + assert api.created[0]["repo_id"] == "alice/ml-intern-user-datasets" + assert api.created[0]["private"] is True + assert manifest["repo_id"] == "alice/ml-intern-user-datasets" + assert manifest["path_prefix"] == "sessions/run-1/upload-1" + uploaded_paths = {upload["path_in_repo"] for upload in api.uploads} + assert "sessions/run-1/upload-1/files/001-train.jsonl" in uploaded_paths + assert "sessions/run-1/upload-1/manifest.json" in uploaded_paths + + +def test_cli_path_args_accept_repeated_and_comma_delimited(): + assert _split_path_args(["a.csv,b.csv", "c.png"]) == ["a.csv", "b.csv", "c.png"] +