# Origin: web-demos/web_app/app.py
"""Flask backend for the interactive video mask / track web demo.

The browser uploads a video, clicks on a frame to build SAM masks, and then
asks the server to propagate those masks through the clip.  All per-user data
lives in the in-memory ``sessions`` dict, keyed by an id stored in the Flask
session cookie.  The model object is a single module-level instance shared by
every request.
"""
import sys
sys.path.append("../../")
sys.path.append("../hugging_face/")

import os
import io
import json
import time
import uuid
import psutil
import base64
import argparse
import tempfile

import cv2
import torch
import torchvision
import numpy as np
from PIL import Image
from flask import Flask, render_template, request, jsonify, send_file, session

from tools.painter import mask_painter, point_painter
from track_anything import TrackingAnything
from model.misc import get_device
from utils.download_util import load_file_from_url

app = Flask(__name__)
app.secret_key = os.urandom(24)

# Frame rate used when the container does not report a usable one.
DEFAULT_FPS = 30

# ── Global state per session (simple single-user version) ──────────────────
sessions = {}


def get_session():
    """Return the state dict for the caller's session cookie, or ``None``."""
    sid = session.get("sid")
    if sid and sid in sessions:
        return sessions[sid]
    return None


def create_session():
    """Create a fresh state dict for the caller and return it.

    Any state previously registered under the caller's cookie is dropped
    first; otherwise every new upload would leave the old decoded frames
    referenced from ``sessions`` forever.
    """
    previous_sid = session.get("sid")
    if previous_sid is not None:
        sessions.pop(previous_sid, None)

    sid = str(uuid.uuid4())
    session["sid"] = sid
    sessions[sid] = {
        "video_state": {
            "user_name": "",
            "video_name": "",
            "origin_images": None,
            "painted_images": None,
            "masks": None,
            "logits": None,
            "select_frame_number": 0,
            "fps": DEFAULT_FPS,
        },
        "interactive_state": {
            "inference_times": 0,
            "negative_click_times": 0,
            "positive_click_times": 0,
            "mask_save": False,
            "multi_mask": {"mask_names": [], "masks": []},
            "track_end_number": None,
        },
        "click_state": [[], []],
    }
    return sessions[sid]


def _session_with_video():
    """Return ``(state, None)`` or ``(None, error_response)``.

    The error response is a ready-to-return ``(json, status)`` pair, used
    when the caller has no session or the session holds no decoded video.
    """
    sess = get_session()
    if not sess:
        return None, (jsonify({"error": "No session"}), 400)
    if sess["video_state"]["origin_images"] is None:
        return None, (jsonify({"error": "No video loaded"}), 400)
    return sess, None


# ── Model initialisation ──────────────────────────────────────────────────
def parse_args():
    """Parse the command line; ``--device`` defaults to ``get_device()``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--sam_model_type", type=str, default="vit_h")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--mask_save", default=False, action="store_true")
    args = parser.parse_args()
    if not args.device:
        args.device = str(get_device())
    return args


args = parse_args()

pretrain_model_url = "https://github.com/sczhou/ProPainter/releases/download/v0.1.0/"
sam_checkpoint_url_dict = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}
checkpoint_folder = os.path.join("..", "..", "weights")

# URLs are joined by plain concatenation (the base already ends in "/");
# os.path.join would insert a backslash on Windows.
sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_folder)
cutie_checkpoint = load_file_from_url(pretrain_model_url + "cutie-base-mega.pth", checkpoint_folder)
propainter_checkpoint = load_file_from_url(pretrain_model_url + "ProPainter.pth", checkpoint_folder)
raft_checkpoint = load_file_from_url(pretrain_model_url + "raft-things.pth", checkpoint_folder)
flow_completion_checkpoint = load_file_from_url(pretrain_model_url + "recurrent_flow_completion.pth", checkpoint_folder)

model = TrackingAnything(sam_checkpoint, cutie_checkpoint, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, args)


# ── Helpers ────────────────────────────────────────────────────────────────
def numpy_to_base64(img_array):
    """Convert an RGB image (numpy array or PIL image) to a base64 JPEG string."""
    if isinstance(img_array, Image.Image):
        img_array = np.array(img_array)
    img_bgr = cv2.cvtColor(img_array.astype(np.uint8), cv2.COLOR_RGB2BGR)
    _, buffer = cv2.imencode(".jpg", img_bgr, [cv2.IMWRITE_JPEG_QUALITY, 85])
    return base64.b64encode(buffer).decode("utf-8")


def generate_video_from_frames(frames, output_path, fps=30):
    """Encode ``frames`` as an H.264 video at ``output_path`` and return the path."""
    frames_tensor = torch.from_numpy(np.asarray(frames))
    output_dir = os.path.dirname(output_path)
    if output_dir:
        # A bare file name has no directory part; makedirs("") would raise.
        os.makedirs(output_dir, exist_ok=True)
    torchvision.io.write_video(output_path, frames_tensor, fps=fps, video_codec="libx264")
    return output_path


def _repaint_with_saved_masks(frame, saved_masks):
    """Return a copy of ``frame`` with every saved mask painted on top."""
    painted = frame.copy()
    for i, m in enumerate(saved_masks):
        painted = mask_painter(painted, m.astype("uint8"), mask_color=i + 2)
    return painted


# ── Routes ─────────────────────────────────────────────────────────────────
@app.route("/")
def index():
    """Serve the single-page front end."""
    return render_template("index.html")


@app.route("/api/upload", methods=["POST"])
def upload_video():
    """Upload a video, extract frames, initialise session state."""
    if "video" not in request.files:
        return jsonify({"error": "No video file provided"}), 400

    video_file = request.files["video"]

    # Reserve a temp path and close the handle before writing to it, so the
    # save also works on platforms that forbid a second open of the file.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    tmp.close()

    frames = []
    fps = 0.0
    cap = None
    try:
        video_file.save(tmp.name)
        cap = cv2.VideoCapture(tmp.name)
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Stop decoding rather than exhaust host memory on long clips.
            if psutil.virtual_memory().percent > 90:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # Runs on the error path too, so neither the capture handle nor the
        # temp file is leaked.
        if cap is not None:
            cap.release()
        os.unlink(tmp.name)

    if not frames:
        return jsonify({"error": "Could not extract any frames from the video"}), 400

    # Containers without frame-rate metadata report 0 (or NaN); both fail
    # the ``> 0`` test and fall back to the default.
    if not fps > 0:
        fps = DEFAULT_FPS

    # The name is later used as an output file name, so strip any directory
    # components a client may have put into it.
    raw_name = (video_file.filename or "").replace("\\", "/")
    video_name = os.path.basename(raw_name) or "video.mp4"

    h, w = frames[0].shape[:2]
    sess = create_session()
    sess["video_state"] = {
        "user_name": time.time(),
        "video_name": video_name,
        "origin_images": frames,
        "painted_images": [f.copy() for f in frames],
        "masks": [np.zeros((h, w), np.uint8) for _ in range(len(frames))],
        "logits": [None] * len(frames),
        "select_frame_number": 0,
        "fps": fps,
    }

    # Set first frame for SAM
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(frames[0])

    return jsonify({
        "total_frames": len(frames),
        "fps": round(fps, 2),
        "width": w,
        "height": h,
        "video_name": video_name,
        "first_frame": numpy_to_base64(frames[0]),
    })


@app.route("/api/frame/<int:frame_num>")
def get_frame(frame_num):
    """Return a specific (painted) frame as base64 JPEG; the index is clamped."""
    sess, error = _session_with_video()
    if error:
        return error
    vs = sess["video_state"]
    idx = max(0, min(frame_num, len(vs["origin_images"]) - 1))
    return jsonify({"frame": numpy_to_base64(vs["painted_images"][idx]), "index": idx})


@app.route("/api/select-frame", methods=["POST"])
def select_frame():
    """Select a template frame for mask painting."""
    sess, error = _session_with_video()
    if error:
        return error

    data = request.get_json(silent=True) or {}
    try:
        frame_idx = int(data.get("frame", 0))
    except (TypeError, ValueError):
        return jsonify({"error": "Invalid frame index"}), 400

    vs = sess["video_state"]
    frame_idx = max(0, min(frame_idx, len(vs["origin_images"]) - 1))
    vs["select_frame_number"] = frame_idx

    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(vs["origin_images"][frame_idx])

    # Reset click state for new frame
    sess["click_state"] = [[], []]

    return jsonify({
        "frame": numpy_to_base64(vs["painted_images"][frame_idx]),
        "index": frame_idx,
    })


@app.route("/api/set-end-frame", methods=["POST"])
def set_end_frame():
    """Set the tracking end frame (``null`` means "track to the last frame")."""
    sess = get_session()
    if not sess:
        return jsonify({"error": "No session"}), 400

    data = request.get_json(silent=True) or {}
    end_frame = data.get("frame")
    # The value is later used as a slice bound, so only integers are accepted.
    if end_frame is not None and (isinstance(end_frame, bool) or not isinstance(end_frame, int)):
        return jsonify({"error": "Invalid end frame"}), 400

    sess["interactive_state"]["track_end_number"] = end_frame
    return jsonify({"end_frame": end_frame})


@app.route("/api/sam-click", methods=["POST"])
def sam_click():
    """Process a click on the frame for SAM segmentation."""
    sess, error = _session_with_video()
    if error:
        return error

    data = request.get_json(silent=True) or {}
    try:
        x = data["x"]
        y = data["y"]
    except KeyError:
        return jsonify({"error": "Click coordinates 'x' and 'y' are required"}), 400
    is_positive = data.get("positive", True)

    vs = sess["video_state"]
    click_state = sess["click_state"]

    click_state[0].append([x, y])
    click_state[1].append(1 if is_positive else 0)

    selected = vs["select_frame_number"]

    # The SAM controller is a module-level singleton, so the image is set
    # again before every prediction.
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(vs["origin_images"][selected])

    prompt = {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        # NOTE(review): passed as the string "True", not a bool; left as-is
        # because the callee's handling is not visible here.
        "multimask_output": "True",
    }

    mask, logit, painted_image = model.first_frame_click(
        image=vs["origin_images"][selected],
        points=np.array(prompt["input_point"]),
        labels=np.array(prompt["input_label"]),
        multimask=prompt["multimask_output"],
    )

    vs["masks"][selected] = mask
    vs["logits"][selected] = logit
    vs["painted_images"][selected] = np.array(painted_image)

    return jsonify({"frame": numpy_to_base64(painted_image)})


@app.route("/api/add-mask", methods=["POST"])
def add_mask():
    """Add current mask to the multi-mask set."""
    sess, error = _session_with_video()
    if error:
        return error

    vs = sess["video_state"]
    ist = sess["interactive_state"]
    selected = vs["select_frame_number"]

    mask = vs["masks"][selected]
    mask_name = "mask_{:03d}".format(len(ist["multi_mask"]["masks"]) + 1)
    ist["multi_mask"]["masks"].append(mask.copy())
    ist["multi_mask"]["mask_names"].append(mask_name)

    # Repaint frame with all masks
    select_frame = _repaint_with_saved_masks(vs["origin_images"][selected], ist["multi_mask"]["masks"])
    vs["painted_images"][selected] = select_frame

    # Reset click state
    sess["click_state"] = [[], []]

    return jsonify({
        "frame": numpy_to_base64(select_frame),
        "masks": ist["multi_mask"]["mask_names"],
    })


@app.route("/api/remove-masks", methods=["POST"])
def remove_masks():
    """Remove all masks."""
    sess, error = _session_with_video()
    if error:
        return error

    ist = sess["interactive_state"]
    vs = sess["video_state"]
    ist["multi_mask"]["masks"] = []
    ist["multi_mask"]["mask_names"] = []

    # Restore original frame
    frame_idx = vs["select_frame_number"]
    vs["painted_images"][frame_idx] = vs["origin_images"][frame_idx].copy()
    vs["masks"][frame_idx] = np.zeros_like(vs["masks"][frame_idx])

    return jsonify({
        "frame": numpy_to_base64(vs["origin_images"][frame_idx]),
        "masks": [],
    })


@app.route("/api/clear-clicks", methods=["POST"])
def clear_clicks():
    """Clear current click state and restore frame."""
    sess, error = _session_with_video()
    if error:
        return error

    sess["click_state"] = [[], []]
    vs = sess["video_state"]
    frame_idx = vs["select_frame_number"]

    # Repaint with existing multi-masks only
    ist = sess["interactive_state"]
    select_frame = _repaint_with_saved_masks(vs["origin_images"][frame_idx], ist["multi_mask"]["masks"])
    vs["painted_images"][frame_idx] = select_frame

    return jsonify({"frame": numpy_to_base64(select_frame)})


@app.route("/api/track", methods=["POST"])
def track_video():
    """Run VOS tracking on the video.

    Combines the selected saved masks into one labelled template mask,
    propagates it from the selected frame to the end frame (or to the end of
    the clip), stores the results in the session and writes a preview video.
    """
    sess, error = _session_with_video()
    if error:
        return error

    vs = sess["video_state"]
    ist = sess["interactive_state"]
    saved_masks = ist["multi_mask"]["masks"]
    data = request.get_json(silent=True) or {}

    mask_selection = data.get("masks", ist["multi_mask"]["mask_names"])
    if not mask_selection:
        mask_selection = ["mask_001"] if saved_masks else []

    if not mask_selection:
        return jsonify({"error": "No masks to track. Please add at least one mask."}), 400

    # Names look like "mask_003"; the numeric part is the 1-based mask id.
    try:
        mask_ids = [int(name.split("_")[1]) for name in sorted(mask_selection)]
    except (AttributeError, IndexError, TypeError, ValueError):
        return jsonify({"error": "Invalid mask selection"}), 400
    if any(mask_id < 1 or mask_id > len(saved_masks) for mask_id in mask_ids):
        return jsonify({"error": "Unknown mask in selection"}), 400

    start = vs["select_frame_number"]
    # A falsy end (None or 0) means "track to the last frame".
    span = slice(start, ist["track_end_number"] or None)
    following_frames = vs["origin_images"][span]
    if not following_frames:
        return jsonify({"error": "Tracking end frame must be after the selected start frame"}), 400

    model.cutie.clear_memory()

    # Build template mask: each selected mask contributes its own id as the
    # label value, clipped so overlaps never exceed the id being added.
    template_mask = saved_masks[mask_ids[0] - 1] * mask_ids[0]
    for mask_id in mask_ids[1:]:
        template_mask = np.clip(
            template_mask + saved_masks[mask_id - 1] * mask_id,
            0, mask_id,
        )

    vs["masks"][start] = template_mask

    # A template with a single unique value gets one pixel set to 1.
    # NOTE(review): presumably so the tracker sees at least one object label.
    if len(np.unique(template_mask)) == 1:
        template_mask[0][0] = 1

    masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
    model.cutie.clear_memory()

    vs["masks"][span] = masks
    vs["logits"][span] = logits
    vs["painted_images"][span] = painted_images

    # Generate tracking output video
    output_dir = os.path.join(".", "result", "track")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, vs["video_name"])
    painted_np = [np.array(img) if isinstance(img, Image.Image) else img for img in vs["painted_images"]]
    video_path = generate_video_from_frames(painted_np, output_path, fps=vs["fps"])

    sess["track_video_path"] = os.path.abspath(video_path)
    ist["inference_times"] += 1

    return jsonify({"status": "ok", "video_url": "/api/result/track"})


+@app.route("/api/inpaint", methods=["POST"]) +def inpaint_video(): + """Run ProPainter inpainting.""" + sess = get_session() + if not sess: + return jsonify({"error": "No session"}), 400 + + vs = sess["video_state"] + ist = sess["interactive_state"] + data = request.json or {} + + resize_ratio = data.get("resize_ratio", 1.0) + dilate_radius = data.get("dilate_radius", 8) + raft_iter = data.get("raft_iter", 20) + subvideo_length = data.get("subvideo_length", 80) + neighbor_length = data.get("neighbor_length", 10) + ref_stride = data.get("ref_stride", 10) + mask_selection = data.get("masks", ist["multi_mask"]["mask_names"]) + + if not mask_selection: + mask_selection = ["mask_001"] + + frames = np.asarray(vs["origin_images"]) + inpaint_masks = np.asarray(vs["masks"]) + + mask_selection_sorted = sorted(mask_selection) + inpaint_mask_numbers = [int(m.split("_")[1]) for m in mask_selection_sorted] + unique_masks = np.unique(inpaint_masks) + num_masks = len(unique_masks) - 1 + for i in range(1, num_masks + 1): + if i in inpaint_mask_numbers: + continue + inpaint_masks[inpaint_masks == i] = 0 + + inpainted_frames = model.baseinpainter.inpaint( + frames, inpaint_masks, + ratio=resize_ratio, + dilate_radius=dilate_radius, + raft_iter=raft_iter, + subvideo_length=subvideo_length, + neighbor_length=neighbor_length, + ref_stride=ref_stride, + ) + + output_dir = os.path.join(".", "result", "inpaint") + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, vs["video_name"]) + video_path = generate_video_from_frames(inpainted_frames, output_path, fps=vs["fps"]) + sess["inpaint_video_path"] = os.path.abspath(video_path) + + return jsonify({"status": "ok", "video_url": "/api/result/inpaint"}) + + +@app.route("/api/result/") +def get_result(result_type): + """Stream the result video.""" + sess = get_session() + if not sess: + return jsonify({"error": "No session"}), 400 + + key = f"{result_type}_video_path" + path = sess.get(key) + if not path or not 
os.path.exists(path): + return jsonify({"error": "No result available"}), 404 + + return send_file(path, mimetype="video/mp4", as_attachment=False) + + +@app.route("/api/reset", methods=["POST"]) +def reset_session(): + """Reset the current session.""" + sid = session.get("sid") + if sid and sid in sessions: + del sessions[sid] + return jsonify({"status": "ok"}) + + +# ── Main ─────────────────────────────────────────────────────────────────── +if __name__ == "__main__": + app.run(host="0.0.0.0", port=args.port, debug=True) diff --git a/web-demos/web_app/requirements.txt b/web-demos/web_app/requirements.txt new file mode 100644 index 00000000..d19207eb --- /dev/null +++ b/web-demos/web_app/requirements.txt @@ -0,0 +1,11 @@ +flask +numpy +opencv-python +torch>=1.7.1 +torchvision>=0.8.2 +Pillow +psutil +tqdm +scipy +pyyaml +segment-anything @ git+https://github.com/facebookresearch/segment-anything.git diff --git a/web-demos/web_app/static/css/style.css b/web-demos/web_app/static/css/style.css new file mode 100644 index 00000000..cb8d0125 --- /dev/null +++ b/web-demos/web_app/static/css/style.css @@ -0,0 +1,528 @@ +/* ── Reset & Base ──────────────────────────────────────────────────────── */ +*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; } + +:root { + --bg: #0f1117; + --bg-card: #1a1d27; + --bg-elevated: #232730; + --border: #2d3140; + --text: #e4e6eb; + --text-muted: #8b8fa3; + --primary: #6366f1; + --primary-hover: #818cf8; + --success: #22c55e; + --success-hover: #16a34a; + --danger: #ef4444; + --danger-hover: #dc2626; + --radius: 10px; + --radius-sm: 6px; + --shadow: 0 2px 8px rgba(0,0,0,0.3); +} + +body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; + background: var(--bg); + color: var(--text); + line-height: 1.6; + min-height: 100vh; +} + +/* ── Header ────────────────────────────────────────────────────────────── */ +header { + text-align: center; + padding: 2rem 1rem 1rem; + border-bottom: 1px 
solid var(--border); +} + +header h1 { + font-size: 2rem; + font-weight: 700; + background: linear-gradient(135deg, var(--primary), #a78bfa); + -webkit-background-clip: text; + -webkit-text-fill-color: transparent; + background-clip: text; +} + +header .subtitle { + color: var(--text-muted); + font-size: 0.95rem; + margin-top: 0.25rem; +} + +/* ── Main ──────────────────────────────────────────────────────────────── */ +main { + max-width: 1100px; + margin: 0 auto; + padding: 1.5rem 1rem 4rem; +} + +/* ── Steps ─────────────────────────────────────────────────────────────── */ +.step { + background: var(--bg-card); + border: 1px solid var(--border); + border-radius: var(--radius); + margin-bottom: 1.25rem; + overflow: hidden; + transition: opacity 0.3s; +} + +.step.disabled { + opacity: 0.5; + pointer-events: none; +} + +.step-header { + display: flex; + align-items: center; + gap: 0.75rem; + padding: 1rem 1.25rem; + border-bottom: 1px solid var(--border); +} + +.step-number { + display: inline-flex; + align-items: center; + justify-content: center; + width: 28px; + height: 28px; + border-radius: 50%; + background: var(--primary); + color: #fff; + font-weight: 700; + font-size: 0.85rem; + flex-shrink: 0; +} + +.step-header h2 { + font-size: 1.1rem; + font-weight: 600; +} + +.step-content { + padding: 1.25rem; +} + +/* ── Upload Area ───────────────────────────────────────────────────────── */ +.upload-area { + border: 2px dashed var(--border); + border-radius: var(--radius); + padding: 3rem 2rem; + text-align: center; + cursor: pointer; + transition: border-color 0.2s, background 0.2s; +} + +.upload-area:hover, +.upload-area.dragover { + border-color: var(--primary); + background: rgba(99, 102, 241, 0.05); +} + +.upload-prompt svg { + color: var(--text-muted); + margin-bottom: 0.75rem; +} + +.upload-prompt p { + font-size: 1rem; + font-weight: 500; + margin-bottom: 0.25rem; +} + +.upload-hint { + color: var(--text-muted); + font-size: 0.85rem; +} + +/* ── Video 
Info ─────────────────────────────────────────────────────────── */ +.video-info { + margin-top: 1rem; +} + +.info-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(140px, 1fr)); + gap: 0.75rem; + margin-bottom: 1rem; +} + +.info-item { + background: var(--bg-elevated); + border-radius: var(--radius-sm); + padding: 0.75rem; +} + +.info-label { + display: block; + font-size: 0.75rem; + color: var(--text-muted); + text-transform: uppercase; + letter-spacing: 0.05em; + margin-bottom: 0.25rem; +} + +.info-item span:last-child { + font-weight: 600; + font-size: 0.95rem; +} + +/* ── Buttons ───────────────────────────────────────────────────────────── */ +.btn { + display: inline-flex; + align-items: center; + justify-content: center; + gap: 0.5rem; + padding: 0.55rem 1.1rem; + border: none; + border-radius: var(--radius-sm); + font-size: 0.9rem; + font-weight: 500; + cursor: pointer; + transition: background 0.15s, transform 0.1s; + color: #fff; +} + +.btn:active { transform: scale(0.97); } +.btn:disabled { opacity: 0.45; cursor: not-allowed; } + +.btn-primary { background: var(--primary); } +.btn-primary:hover:not(:disabled) { background: var(--primary-hover); } + +.btn-success { background: var(--success); } +.btn-success:hover:not(:disabled) { background: var(--success-hover); } + +.btn-danger { background: var(--danger); } +.btn-danger:hover:not(:disabled) { background: var(--danger-hover); } + +.btn-secondary { background: var(--bg-elevated); border: 1px solid var(--border); color: var(--text); } +.btn-secondary:hover:not(:disabled) { background: var(--border); } + +.btn-lg { padding: 0.75rem 1.5rem; font-size: 1rem; } + +/* ── Canvas & Mask Workspace ───────────────────────────────────────────── */ +.mask-workspace { + display: grid; + grid-template-columns: 1fr 280px; + gap: 1.25rem; +} + +@media (max-width: 768px) { + .mask-workspace { + grid-template-columns: 1fr; + } +} + +.canvas-container { + position: relative; + background: #000; + 
border-radius: var(--radius); + overflow: hidden; + aspect-ratio: 16/9; + display: flex; + align-items: center; + justify-content: center; +} + +#frame-canvas { + max-width: 100%; + max-height: 100%; + cursor: crosshair; + display: block; +} + +.canvas-overlay { + position: absolute; + inset: 0; + display: flex; + align-items: center; + justify-content: center; + background: rgba(0, 0, 0, 0.6); + color: var(--text-muted); + font-size: 0.95rem; +} + +.canvas-overlay.hidden { display: none; } + +/* ── Mask Controls ─────────────────────────────────────────────────────── */ +.mask-controls { + display: flex; + flex-direction: column; + gap: 1rem; +} + +.control-group label { + display: block; + font-size: 0.8rem; + color: var(--text-muted); + text-transform: uppercase; + letter-spacing: 0.04em; + margin-bottom: 0.4rem; +} + +.slider-row { + display: flex; + align-items: center; + gap: 0.75rem; +} + +.slider-row input[type="range"] { + flex: 1; + -webkit-appearance: none; + appearance: none; + height: 4px; + background: var(--border); + border-radius: 2px; + outline: none; +} + +.slider-row input[type="range"]::-webkit-slider-thumb { + -webkit-appearance: none; + width: 16px; + height: 16px; + border-radius: 50%; + background: var(--primary); + cursor: pointer; + border: 2px solid var(--bg-card); +} + +.slider-value { + min-width: 36px; + text-align: right; + font-size: 0.85rem; + font-weight: 600; + color: var(--text-muted); +} + +/* Toggle group */ +.toggle-group { + display: flex; + gap: 0.5rem; +} + +.toggle-btn { + flex: 1; + display: inline-flex; + align-items: center; + justify-content: center; + gap: 0.35rem; + padding: 0.45rem 0.5rem; + border: 1px solid var(--border); + border-radius: var(--radius-sm); + background: var(--bg-elevated); + color: var(--text-muted); + font-size: 0.85rem; + cursor: pointer; + transition: all 0.15s; +} + +.toggle-btn:disabled { opacity: 0.45; cursor: not-allowed; } +.toggle-btn.active { background: var(--primary); border-color: 
var(--primary); color: #fff; } + +/* Button stack */ +.button-stack { + display: flex; + flex-direction: column; + gap: 0.5rem; +} + +.button-stack .btn { width: 100%; } + +/* Mask list */ +.mask-list { + background: var(--bg-elevated); + border-radius: var(--radius-sm); + padding: 0.5rem 0.75rem; + min-height: 40px; + display: flex; + flex-wrap: wrap; + gap: 0.4rem; + align-items: center; +} + +.mask-list .empty-state { + color: var(--text-muted); + font-size: 0.85rem; +} + +.mask-tag { + display: inline-flex; + align-items: center; + gap: 0.3rem; + padding: 0.2rem 0.6rem; + border-radius: 20px; + font-size: 0.8rem; + font-weight: 500; + color: #fff; +} + +.mask-tag.color-0 { background: #6366f1; } +.mask-tag.color-1 { background: #f59e0b; } +.mask-tag.color-2 { background: #10b981; } +.mask-tag.color-3 { background: #ef4444; } +.mask-tag.color-4 { background: #8b5cf6; } +.mask-tag.color-5 { background: #ec4899; } + +/* ── Parameters Panel ──────────────────────────────────────────────────── */ +.params-panel { + background: var(--bg-elevated); + border: 1px solid var(--border); + border-radius: var(--radius); + margin-bottom: 1.25rem; +} + +.params-panel summary { + padding: 0.75rem 1rem; + cursor: pointer; + font-weight: 500; + color: var(--text-muted); + user-select: none; +} + +.params-panel summary:hover { color: var(--text); } + +.params-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); + gap: 1rem; + padding: 0 1rem 1rem; +} + +.param-item label { + display: block; + font-size: 0.8rem; + color: var(--text-muted); + margin-bottom: 0.35rem; +} + +/* ── Action Row ────────────────────────────────────────────────────────── */ +.action-row { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 1.25rem; +} + +@media (max-width: 640px) { + .action-row { grid-template-columns: 1fr; } +} + +.action-card { + background: var(--bg-elevated); + border: 1px solid var(--border); + border-radius: var(--radius); + padding: 1.25rem; 
+ text-align: center; +} + +.action-card h3 { + font-size: 1rem; + margin-bottom: 0.4rem; +} + +.action-card p { + font-size: 0.85rem; + color: var(--text-muted); + margin-bottom: 1rem; +} + +.result-video { + margin-top: 1rem; +} + +.result-video video { + width: 100%; + border-radius: var(--radius-sm); + background: #000; +} + +/* ── Status Bar ────────────────────────────────────────────────────────── */ +.status-bar { + position: fixed; + bottom: 0; + left: 0; + right: 0; + background: var(--bg-card); + border-top: 1px solid var(--border); + padding: 0.6rem 1.5rem; + display: flex; + align-items: center; + justify-content: space-between; + z-index: 100; + font-size: 0.85rem; +} + +.status-message { + color: var(--text-muted); +} + +.status-message.error { + color: var(--danger); +} + +.status-message.success { + color: var(--success); +} + +/* Spinner */ +.spinner { + width: 18px; + height: 18px; + border: 2px solid var(--border); + border-top-color: var(--primary); + border-radius: 50%; + animation: spin 0.6s linear infinite; +} + +@keyframes spin { + to { transform: rotate(360deg); } +} + +/* ── Footer ────────────────────────────────────────────────────────────── */ +footer { + text-align: center; + padding: 1.5rem; + color: var(--text-muted); + font-size: 0.85rem; + border-top: 1px solid var(--border); + margin-top: 2rem; +} + +footer a { + color: var(--primary); + text-decoration: none; +} + +footer a:hover { text-decoration: underline; } + +/* ── Loading overlay for long operations ───────────────────────────────── */ +.loading-overlay { + position: fixed; + inset: 0; + background: rgba(15, 17, 23, 0.85); + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + z-index: 200; + gap: 1rem; +} + +.loading-overlay .spinner-lg { + width: 40px; + height: 40px; + border: 3px solid var(--border); + border-top-color: var(--primary); + border-radius: 50%; + animation: spin 0.7s linear infinite; +} + +.loading-overlay p { + 
color: var(--text); + font-size: 1rem; + font-weight: 500; +} + +/* ── Scrollbar ─────────────────────────────────────────────────────────── */ +::-webkit-scrollbar { width: 8px; } +::-webkit-scrollbar-track { background: var(--bg); } +::-webkit-scrollbar-thumb { background: var(--border); border-radius: 4px; } +::-webkit-scrollbar-thumb:hover { background: var(--text-muted); } diff --git a/web-demos/web_app/static/js/app.js b/web-demos/web_app/static/js/app.js new file mode 100644 index 00000000..da4ef84a --- /dev/null +++ b/web-demos/web_app/static/js/app.js @@ -0,0 +1,442 @@ +// ── State ────────────────────────────────────────────────────────────── +const state = { + totalFrames: 0, + currentFrame: 0, + endFrame: 0, + isPositive: true, + masks: [], + imageWidth: 0, + imageHeight: 0, + canvasScale: 1, + busy: false, +}; + +// ── DOM Elements ────────────────────────────────────────────────────── +const $ = (sel) => document.querySelector(sel); +const uploadArea = $("#upload-area"); +const uploadPrompt = $("#upload-prompt"); +const videoInput = $("#video-input"); +const videoInfo = $("#video-info"); +const canvas = $("#frame-canvas"); +const ctx = canvas.getContext("2d"); +const canvasOverlay = $("#canvas-overlay"); + +const startSlider = $("#start-frame-slider"); +const endSlider = $("#end-frame-slider"); +const startLabel = $("#start-frame-label"); +const endLabel = $("#end-frame-label"); + +const btnPositive = $("#btn-positive"); +const btnNegative = $("#btn-negative"); +const btnAddMask = $("#btn-add-mask"); +const btnClearClicks = $("#btn-clear-clicks"); +const btnRemoveMasks = $("#btn-remove-masks"); +const btnTrack = $("#btn-track"); +const btnInpaint = $("#btn-inpaint"); +const btnReset = $("#btn-reset"); + +const maskList = $("#mask-list"); +const statusMessage = $("#status-message"); +const spinner = $("#spinner"); + +const step2 = $("#step2"); +const step3 = $("#step3"); + +// Parameter elements +const paramSliders = { + resize: { slider: 
$("#param-resize"), display: $("#val-resize") }, + dilate: { slider: $("#param-dilate"), display: $("#val-dilate") }, + raft: { slider: $("#param-raft"), display: $("#val-raft") }, + subvideo: { slider: $("#param-subvideo"), display: $("#val-subvideo") }, + neighbor: { slider: $("#param-neighbor"), display: $("#val-neighbor") }, + stride: { slider: $("#param-stride"), display: $("#val-stride") }, +}; + +// ── Helpers ──────────────────────────────────────────────────────────── +function setStatus(msg, type = "") { + statusMessage.textContent = msg; + statusMessage.className = "status-message" + (type ? ` ${type}` : ""); +} + +function showSpinner(show) { + spinner.style.display = show ? "block" : "none"; +} + +function showLoading(msg) { + const overlay = document.createElement("div"); + overlay.className = "loading-overlay"; + overlay.id = "loading-overlay"; + overlay.innerHTML = `

${msg}

`; + document.body.appendChild(overlay); +} + +function hideLoading() { + const overlay = document.getElementById("loading-overlay"); + if (overlay) overlay.remove(); +} + +function enableStep(step) { + step.classList.remove("disabled"); +} + +function disableStep(step) { + step.classList.add("disabled"); +} + +function enableControls(enable) { + const controls = [btnPositive, btnNegative, btnAddMask, btnClearClicks, btnRemoveMasks, startSlider, endSlider]; + controls.forEach((el) => (el.disabled = !enable)); +} + +async function apiCall(url, options = {}) { + const res = await fetch(url, { + credentials: "same-origin", + ...options, + }); + if (!res.ok) { + const err = await res.json().catch(() => ({ error: "Request failed" })); + throw new Error(err.error || "Request failed"); + } + return res.json(); +} + +function drawImageOnCanvas(base64) { + return new Promise((resolve) => { + const img = new Image(); + img.onload = () => { + state.imageWidth = img.width; + state.imageHeight = img.height; + + // Scale canvas to fit container while maintaining aspect ratio + const container = canvas.parentElement; + const containerW = container.clientWidth; + const containerH = container.clientHeight; + const scale = Math.min(containerW / img.width, containerH / img.height, 1); + + canvas.width = img.width; + canvas.height = img.height; + canvas.style.width = `${img.width * scale}px`; + canvas.style.height = `${img.height * scale}px`; + state.canvasScale = scale; + + ctx.drawImage(img, 0, 0); + resolve(); + }; + img.src = "data:image/jpeg;base64," + base64; + }); +} + +function updateMaskList() { + if (state.masks.length === 0) { + maskList.innerHTML = 'No masks added'; + } else { + maskList.innerHTML = state.masks + .map((name, i) => `${name}`) + .join(""); + } +} + +// ── Upload ───────────────────────────────────────────────────────────── +uploadArea.addEventListener("click", () => videoInput.click()); +uploadArea.addEventListener("dragover", (e) => { + e.preventDefault(); 
+ uploadArea.classList.add("dragover"); +}); +uploadArea.addEventListener("dragleave", () => uploadArea.classList.remove("dragover")); +uploadArea.addEventListener("drop", (e) => { + e.preventDefault(); + uploadArea.classList.remove("dragover"); + if (e.dataTransfer.files.length) { + handleFile(e.dataTransfer.files[0]); + } +}); +videoInput.addEventListener("change", () => { + if (videoInput.files.length) handleFile(videoInput.files[0]); +}); + +async function handleFile(file) { + if (state.busy) return; + state.busy = true; + showSpinner(true); + setStatus("Uploading and extracting frames..."); + + const formData = new FormData(); + formData.append("video", file); + + try { + const data = await apiCall("/api/upload", { + method: "POST", + body: formData, + }); + + state.totalFrames = data.total_frames; + state.currentFrame = 0; + state.endFrame = data.total_frames - 1; + state.masks = []; + + // Show video info + $("#info-name").textContent = data.video_name; + $("#info-fps").textContent = data.fps; + $("#info-frames").textContent = data.total_frames; + $("#info-size").textContent = `${data.width} x ${data.height}`; + videoInfo.style.display = "block"; + uploadArea.style.display = "none"; + + // Setup sliders + startSlider.max = data.total_frames - 1; + startSlider.value = 0; + startLabel.textContent = "0"; + endSlider.max = data.total_frames - 1; + endSlider.value = data.total_frames - 1; + endLabel.textContent = String(data.total_frames - 1); + + // Draw first frame + canvasOverlay.classList.add("hidden"); + await drawImageOnCanvas(data.first_frame); + + // Enable steps + enableStep(step2); + enableStep(step3); + enableControls(true); + updateMaskList(); + + setStatus("Video loaded. 
Click on the image to select objects for removal.", "success"); + } catch (err) { + setStatus("Upload failed: " + err.message, "error"); + } finally { + state.busy = false; + showSpinner(false); + } +} + +// ── Reset ────────────────────────────────────────────────────────────── +btnReset.addEventListener("click", async () => { + await apiCall("/api/reset", { method: "POST" }); + location.reload(); +}); + +// ── Frame Selection ─────────────────────────────────────────────────── +startSlider.addEventListener("input", () => { + startLabel.textContent = startSlider.value; +}); + +startSlider.addEventListener("change", async () => { + if (state.busy) return; + state.busy = true; + showSpinner(true); + const frame = parseInt(startSlider.value); + state.currentFrame = frame; + setStatus(`Selecting frame ${frame}...`); + + try { + const data = await apiCall("/api/select-frame", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ frame }), + }); + await drawImageOnCanvas(data.frame); + setStatus(`Frame ${frame} selected. 
Click on objects to create masks.`, "success"); + } catch (err) { + setStatus("Error: " + err.message, "error"); + } finally { + state.busy = false; + showSpinner(false); + } +}); + +endSlider.addEventListener("input", () => { + endLabel.textContent = endSlider.value; +}); + +endSlider.addEventListener("change", async () => { + if (state.busy) return; + const endFrame = parseInt(endSlider.value); + state.endFrame = endFrame; + try { + await apiCall("/api/set-end-frame", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ frame: endFrame }), + }); + setStatus(`Tracking end frame set to ${endFrame}.`); + } catch (err) { + setStatus("Error: " + err.message, "error"); + } +}); + +// ── Click Mode Toggle ───────────────────────────────────────────────── +btnPositive.addEventListener("click", () => { + state.isPositive = true; + btnPositive.classList.add("active"); + btnNegative.classList.remove("active"); +}); + +btnNegative.addEventListener("click", () => { + state.isPositive = false; + btnNegative.classList.add("active"); + btnPositive.classList.remove("active"); +}); + +// ── Canvas Click (SAM) ──────────────────────────────────────────────── +canvas.addEventListener("click", async (e) => { + if (state.busy) return; + state.busy = true; + showSpinner(true); + + const rect = canvas.getBoundingClientRect(); + const scaleX = canvas.width / rect.width; + const scaleY = canvas.height / rect.height; + const x = Math.round((e.clientX - rect.left) * scaleX); + const y = Math.round((e.clientY - rect.top) * scaleY); + + setStatus(`Processing click at (${x}, ${y})...`); + + try { + const data = await apiCall("/api/sam-click", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ x, y, positive: state.isPositive }), + }); + await drawImageOnCanvas(data.frame); + setStatus('Mask generated. 
Click "Add Mask" to keep it, or click again to refine.', "success"); + } catch (err) { + setStatus("Error: " + err.message, "error"); + } finally { + state.busy = false; + showSpinner(false); + } +}); + +// ── Mask Actions ────────────────────────────────────────────────────── +btnAddMask.addEventListener("click", async () => { + if (state.busy) return; + state.busy = true; + showSpinner(true); + setStatus("Adding mask..."); + + try { + const data = await apiCall("/api/add-mask", { method: "POST" }); + state.masks = data.masks; + await drawImageOnCanvas(data.frame); + updateMaskList(); + btnTrack.disabled = false; + setStatus(`Mask added (${state.masks.length} total). Add more or proceed to tracking.`, "success"); + } catch (err) { + setStatus("Error: " + err.message, "error"); + } finally { + state.busy = false; + showSpinner(false); + } +}); + +btnClearClicks.addEventListener("click", async () => { + if (state.busy) return; + state.busy = true; + showSpinner(true); + + try { + const data = await apiCall("/api/clear-clicks", { method: "POST" }); + await drawImageOnCanvas(data.frame); + setStatus("Clicks cleared. 
Click on the image to start a new mask."); + } catch (err) { + setStatus("Error: " + err.message, "error"); + } finally { + state.busy = false; + showSpinner(false); + } +}); + +btnRemoveMasks.addEventListener("click", async () => { + if (state.busy) return; + state.busy = true; + showSpinner(true); + + try { + const data = await apiCall("/api/remove-masks", { method: "POST" }); + state.masks = []; + await drawImageOnCanvas(data.frame); + updateMaskList(); + btnTrack.disabled = true; + btnInpaint.disabled = true; + setStatus("All masks removed."); + } catch (err) { + setStatus("Error: " + err.message, "error"); + } finally { + state.busy = false; + showSpinner(false); + } +}); + +// ── Tracking ────────────────────────────────────────────────────────── +btnTrack.addEventListener("click", async () => { + if (state.busy) return; + state.busy = true; + showLoading("Tracking objects across frames... This may take a while."); + setStatus("Running tracking..."); + + try { + const data = await apiCall("/api/track", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ masks: state.masks }), + }); + + // Show tracking result + const trackResult = $("#track-result"); + const trackVideo = $("#track-video"); + trackVideo.src = data.video_url + "?t=" + Date.now(); + trackResult.style.display = "block"; + btnInpaint.disabled = false; + setStatus("Tracking complete! You can now run inpainting.", "success"); + } catch (err) { + setStatus("Tracking failed: " + err.message, "error"); + } finally { + state.busy = false; + hideLoading(); + } +}); + +// ── Inpainting ──────────────────────────────────────────────────────── +btnInpaint.addEventListener("click", async () => { + if (state.busy) return; + state.busy = true; + showLoading("Running ProPainter inpainting... 
This may take a while."); + setStatus("Running inpainting..."); + + const params = { + masks: state.masks, + resize_ratio: parseFloat(paramSliders.resize.slider.value), + dilate_radius: parseInt(paramSliders.dilate.slider.value), + raft_iter: parseInt(paramSliders.raft.slider.value), + subvideo_length: parseInt(paramSliders.subvideo.slider.value), + neighbor_length: parseInt(paramSliders.neighbor.slider.value), + ref_stride: parseInt(paramSliders.stride.slider.value), + }; + + try { + const data = await apiCall("/api/inpaint", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(params), + }); + + const inpaintResult = $("#inpaint-result"); + const inpaintVideo = $("#inpaint-video"); + inpaintVideo.src = data.video_url + "?t=" + Date.now(); + inpaintResult.style.display = "block"; + setStatus("Inpainting complete! The result video is ready.", "success"); + } catch (err) { + setStatus("Inpainting failed: " + err.message, "error"); + } finally { + state.busy = false; + hideLoading(); + } +}); + +// ── Parameter sliders live update ───────────────────────────────────── +Object.values(paramSliders).forEach(({ slider, display }) => { + slider.addEventListener("input", () => { + const val = parseFloat(slider.value); + display.textContent = Number.isInteger(val) ? val : val.toFixed(2); + }); +}); diff --git a/web-demos/web_app/templates/index.html b/web-demos/web_app/templates/index.html new file mode 100644 index 00000000..6d118f11 --- /dev/null +++ b/web-demos/web_app/templates/index.html @@ -0,0 +1,215 @@ + + + + + + ProPainter - Video Inpainting + + + +
+

ProPainter

+

Improving Propagation and Transformer for Video Inpainting

+
+ +
+ +
+
+ 1 +

Upload Video

+
+
+
+
+ + + + + +

Drop your video here or click to browse

+ Supports MP4, MOV, AVI +
+ +
+ +
+
+ + +
+
+ 2 +

Add Masks

+
+
+
+
+ +
+

Upload a video to get started

+
+
+ +
+ +
+ +
+ + 0 +
+
+
+ +
+ + 0 +
+
+ + +
+ +
+ + +
+
+ + +
+ +
+ + + +
+
+ + +
+ +
+ No masks added +
+
+
+
+
+
+ + +
+
+ 3 +

Track & Inpaint

+
+
+ +
+ ProPainter Parameters +
+
+ +
+ + 1.00 +
+
+
+ +
+ + 8 +
+
+
+ +
+ + 20 +
+
+
+ +
+ + 80 +
+
+
+ +
+ + 10 +
+
+
+ +
+ + 10 +
+
+
+
+ +
+
+

Step 1: Track Objects

+

Track selected masks across all video frames using CUTIE.

+ + +
+
+

Step 2: Inpaint Video

+

Remove tracked objects and fill with ProPainter.

+ + +
+
+
+
+ + +
+
Ready. Upload a video to begin.
+ +
+
+ +
+

+ ProPainter — + Improving Propagation and Transformer for Video Inpainting (ICCV 2023) +

+
+ + + +