#!/usr/bin/env python3 """ TransNetV2 Scene Detection Service (PyTorch Version) A Flask API wrapping TransNetV2 for neural shot boundary detection. Uses PyTorch implementation with weights from HuggingFace. Returns cut timestamps and keyframe suggestions for each scene. """ import os import sys import subprocess import numpy as np import torch from flask import Flask, request, jsonify # Add TransNetV2 PyTorch to path sys.path.insert(0, '/app/TransNetV2/inference-pytorch') app = Flask(__name__) # Configuration MIN_SCENE_LENGTH = float(os.getenv('TRANSNET_MIN_SCENE_LENGTH', '1.0')) THRESHOLD = float(os.getenv('TRANSNET_THRESHOLD', '0.5')) WEIGHTS_PATH = os.getenv('TRANSNET_WEIGHTS', '/app/weights/transnetv2-pytorch-weights.pth') DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' # Model will be loaded on startup model = None def load_model(): """Load TransNetV2 PyTorch model""" global model from transnetv2_pytorch import TransNetV2 print(f"Loading TransNetV2 PyTorch model on {DEVICE}...") model = TransNetV2() # Load weights state_dict = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True) model.load_state_dict(state_dict) model.to(DEVICE) model.eval() print("TransNetV2 model loaded successfully") def extract_frames(video_path): """ Extract frames from video using ffmpeg. Returns numpy array of shape [N, H, W, 3] with uint8 values. TransNetV2 expects 48x27 resolution. """ # Get video info first probe_cmd = [ 'ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 'stream=width,height,r_frame_rate', '-show_entries', 'format=duration', '-of', 'csv=p=0', video_path ] result = subprocess.run(probe_cmd, capture_output=True, text=True) lines = result.stdout.strip().split('\n') # Parse duration duration = float(lines[-1]) if lines else 0 # Parse fps from r_frame_rate (e.g., "30000/1001" or "30/1") fps_str = lines[0].split(',')[2] if len(lines) > 0 and ',' in lines[0] else "30/1" if '/' in fps_str: num, den = fps_str.split('/') fps = float(num) / float(den) else: fps = float(fps_str) # Extract frames at 48x27 resolution (TransNetV2 input size) # Use a reasonable frame rate (original or max 25fps to save memory) sample_fps = min(fps, 25) ffmpeg_cmd = [ 'ffmpeg', '-i', video_path, '-vf', f'fps={sample_fps},scale=48:27', '-pix_fmt', 'rgb24', '-f', 'rawvideo', '-' ] result = subprocess.run(ffmpeg_cmd, capture_output=True) if result.returncode != 0: raise RuntimeError(f"FFmpeg failed: {result.stderr.decode()}") # Parse raw video frames frame_size = 48 * 27 * 3 raw_frames = result.stdout num_frames = len(raw_frames) // frame_size frames = np.frombuffer(raw_frames, dtype=np.uint8) frames = frames[:num_frames * frame_size].reshape(num_frames, 27, 48, 3) return frames, fps, duration, sample_fps def predict_scenes(frames): """ Run TransNetV2 on frames to predict shot boundaries. The model expects batches of 100 frames and outputs per-frame predictions. We use overlapping windows for better accuracy at boundaries. """ WINDOW_SIZE = 100 num_frames = len(frames) predictions = np.zeros(num_frames, dtype=np.float32) # Process in windows of 100 frames with 50 frame overlap for start in range(0, num_frames, WINDOW_SIZE // 2): end = min(start + WINDOW_SIZE, num_frames) window_frames = frames[start:end] # Pad if necessary if len(window_frames) < WINDOW_SIZE: padding = np.zeros((WINDOW_SIZE - len(window_frames), 27, 48, 3), dtype=np.uint8) window_frames = np.concatenate([window_frames, padding], axis=0) # Convert to tensor [1, T, H, W, C] frames_tensor = torch.from_numpy(window_frames).unsqueeze(0).to(DEVICE) with torch.no_grad(): # Model returns predictions for each frame output = model(frames_tensor) if isinstance(output, tuple): output = output[0] # Get single-frame predictions # Output is [B, T, 1] - squeeze and get probabilities preds = torch.sigmoid(output).squeeze().cpu().numpy() # Store predictions (use max for overlapping regions) actual_frames = min(end - start, len(preds)) for i in range(actual_frames): frame_idx = start + i if frame_idx < num_frames: predictions[frame_idx] = max(predictions[frame_idx], preds[i]) return predictions def predictions_to_scenes(predictions, threshold): """ Convert per-frame predictions to scene boundaries. Returns list of (start_frame, end_frame) tuples. """ scenes = [] scene_start = 0 for i in range(1, len(predictions)): if predictions[i] > threshold: # Shot boundary detected if i > scene_start: scenes.append((scene_start, i)) scene_start = i # Add final scene if scene_start < len(predictions): scenes.append((scene_start, len(predictions))) return scenes @app.route('/health', methods=['GET']) def health(): """Health check endpoint""" return jsonify({ 'status': 'ok', 'model': 'TransNetV2-PyTorch', 'device': DEVICE, 'threshold': THRESHOLD, 'min_scene_length': MIN_SCENE_LENGTH }) @app.route('/detect', methods=['POST']) def detect(): """ Detect scene boundaries using TransNetV2. Expects JSON body with: - video_path: path to video file (must be accessible from container) - threshold: optional override for prediction threshold (0-1) - extract_keyframes: if true, return suggested keyframe timestamps Returns JSON with: - cuts: array of cut timestamps in seconds - scenes: array of {start, end, duration, keyframe} objects - scene_count: number of scenes detected - duration: video duration in seconds """ data = request.get_json() if not data or 'video_path' not in data: return jsonify({'error': 'video_path is required'}), 400 video_path = data['video_path'] threshold = data.get('threshold', THRESHOLD) extract_keyframes = data.get('extract_keyframes', True) if not os.path.exists(video_path): return jsonify({'error': f'Video file not found: {video_path}'}), 404 try: # Extract frames from video frames, original_fps, duration, sample_fps = extract_frames(video_path) frame_count = len(frames) if frame_count == 0: return jsonify({'error': 'No frames extracted from video'}), 400 # Run TransNetV2 prediction predictions = predict_scenes(frames) # Convert predictions to scenes scenes_frames = predictions_to_scenes(predictions, threshold) # Convert frame indices to timestamps (using sample_fps since that's what we extracted at) cuts = [0.0] scenes = [] predictions_at_cuts = [] for scene_start, scene_end in scenes_frames: start_sec = scene_start / sample_fps end_sec = scene_end / sample_fps scene_duration = end_sec - start_sec # Skip very short scenes (likely noise) if scene_duration < MIN_SCENE_LENGTH: continue # Calculate keyframe (midpoint by default) keyframe_sec = start_sec + (scene_duration / 2) # For long scenes (>2 min), use 1/3 point for more action coverage if scene_duration > 120: keyframe_sec = start_sec + (scene_duration / 3) scenes.append({ 'start': round(start_sec, 3), 'end': round(end_sec, 3), 'duration': round(scene_duration, 3), 'keyframe': round(keyframe_sec, 3) if extract_keyframes else None }) # Add cut timestamp if not already present if start_sec > 0 and start_sec not in cuts: cuts.append(round(start_sec, 3)) # Get prediction score at this frame if scene_start < len(predictions): predictions_at_cuts.append(float(predictions[scene_start])) # Always include video end if duration not in cuts: cuts.append(round(duration, 3)) cuts = sorted(list(set(cuts))) # Extract keyframes list for convenience keyframes = [s['keyframe'] for s in scenes if s['keyframe'] is not None] return jsonify({ 'cuts': cuts, 'scenes': scenes, 'keyframes': keyframes, 'scene_count': len(scenes), 'duration': round(duration, 3), 'fps': round(original_fps, 2), 'sample_fps': round(sample_fps, 2), 'frame_count': frame_count, 'threshold': threshold, 'predictions_at_cuts': predictions_at_cuts }) except Exception as e: import traceback return jsonify({ 'error': str(e), 'traceback': traceback.format_exc() }), 500 @app.route('/detect_batch', methods=['POST']) def detect_batch(): """ Detect scenes for multiple videos in batch. Expects JSON body with: - video_paths: array of paths to video files Returns JSON with: - results: array of detection results (same format as /detect) """ data = request.get_json() if not data or 'video_paths' not in data: return jsonify({'error': 'video_paths array is required'}), 400 results = [] for video_path in data['video_paths']: try: single_result = detect_single(video_path, data.get('threshold', THRESHOLD)) results.append({'video_path': video_path, **single_result}) except Exception as e: results.append({'video_path': video_path, 'error': str(e)}) return jsonify({'results': results}) def detect_single(video_path, threshold): """Helper for batch detection""" if not os.path.exists(video_path): raise FileNotFoundError(f'Video file not found: {video_path}') frames, original_fps, duration, sample_fps = extract_frames(video_path) if len(frames) == 0: raise ValueError('No frames extracted from video') predictions = predict_scenes(frames) scenes_frames = predictions_to_scenes(predictions, threshold) cuts = [0.0] scenes = [] for scene_start, scene_end in scenes_frames: start_sec = scene_start / sample_fps end_sec = scene_end / sample_fps scene_duration = end_sec - start_sec if scene_duration < MIN_SCENE_LENGTH: continue keyframe_sec = start_sec + (scene_duration / 2) if scene_duration > 120: keyframe_sec = start_sec + (scene_duration / 3) scenes.append({ 'start': round(start_sec, 3), 'end': round(end_sec, 3), 'duration': round(scene_duration, 3), 'keyframe': round(keyframe_sec, 3) }) if start_sec > 0: cuts.append(round(start_sec, 3)) cuts.append(round(duration, 3)) cuts = sorted(list(set(cuts))) keyframes = [s['keyframe'] for s in scenes] return { 'cuts': cuts, 'scenes': scenes, 'keyframes': keyframes, 'scene_count': len(scenes), 'duration': round(duration, 3) } @app.route('/info', methods=['GET']) def info(): """Get service configuration""" return jsonify({ 'service': 'transnetv2', 'model': 'TransNetV2-PyTorch', 'device': DEVICE, 'description': 'Neural shot boundary detection', 'threshold': THRESHOLD, 'min_scene_length': MIN_SCENE_LENGTH, 'features': [ 'Hard cut detection', 'Soft transition detection (dissolves, fades)', 'Keyframe extraction', 'Batch processing', 'GPU acceleration' ] }) if __name__ == '__main__': load_model() print(f"TransNetV2 Scene Detection Service starting...") print(f"Device: {DEVICE}") print(f"Threshold: {THRESHOLD}") print(f"Min scene length: {MIN_SCENE_LENGTH}s") app.run(host='0.0.0.0', port=5005, debug=False)