#!/usr/bin/env python3
"""
TransNetV2 Scene Detection Service (PyTorch Version)
A Flask API wrapping TransNetV2 for neural shot boundary detection.
Uses PyTorch implementation with weights from HuggingFace.
Returns cut timestamps and keyframe suggestions for each scene.
"""
import os
import sys
import subprocess
import numpy as np
import torch
from flask import Flask, request, jsonify
# Add TransNetV2 PyTorch to path
sys.path.insert(0, '/app/TransNetV2/inference-pytorch')
app = Flask(__name__)
# Configuration
MIN_SCENE_LENGTH = float(os.getenv('TRANSNET_MIN_SCENE_LENGTH', '1.0'))
THRESHOLD = float(os.getenv('TRANSNET_THRESHOLD', '0.5'))
WEIGHTS_PATH = os.getenv('TRANSNET_WEIGHTS', '/app/weights/transnetv2-pytorch-weights.pth')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Model will be loaded on startup
model = None
def load_model():
"""Load TransNetV2 PyTorch model"""
global model
from transnetv2_pytorch import TransNetV2
print(f"Loading TransNetV2 PyTorch model on {DEVICE}...")
model = TransNetV2()
# Load weights
state_dict = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()
print("TransNetV2 model loaded successfully")
def extract_frames(video_path):
"""
Extract frames from video using ffmpeg.
Returns numpy array of shape [N, H, W, 3] with uint8 values.
TransNetV2 expects 48x27 resolution.
"""
# Get video info first
probe_cmd = [
'ffprobe', '-v', 'error',
'-select_streams', 'v:0',
'-show_entries', 'stream=width,height,r_frame_rate',
'-show_entries', 'format=duration',
'-of', 'csv=p=0',
video_path
]
result = subprocess.run(probe_cmd, capture_output=True, text=True)
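    # Illustrative ffprobe output for the command above (values are examples):
    #   1920,1080,30000/1001   <- stream line: width,height,r_frame_rate
    #   932.640000             <- format line: duration in seconds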
    if result.returncode != 0 or not result.stdout.strip():
        raise RuntimeError(f"ffprobe failed: {result.stderr}")
    lines = result.stdout.strip().split('\n')
    # Parse duration from the format line
    duration = float(lines[-1])
    # Parse fps from r_frame_rate (e.g. "30000/1001" or "30/1")
    fps_str = lines[0].split(',')[2] if ',' in lines[0] else "30/1"
    if '/' in fps_str:
        num, den = fps_str.split('/')
        fps = float(num) / float(den)
    else:
        fps = float(fps_str)
    # Extract frames at 48x27 (the TransNetV2 input size), capped at 25 fps
    # to save memory
    sample_fps = min(fps, 25)
    ffmpeg_cmd = [
        'ffmpeg', '-i', video_path,
        '-vf', f'fps={sample_fps},scale=48:27',
        '-pix_fmt', 'rgb24',
        '-f', 'rawvideo',
        '-'
    ]
    result = subprocess.run(ffmpeg_cmd, capture_output=True)
    if result.returncode != 0:
        raise RuntimeError(f"FFmpeg failed: {result.stderr.decode()}")
    # Parse the raw RGB frames: 48 px wide, 27 px tall, 3 bytes per pixel
    frame_size = 48 * 27 * 3
    raw_frames = result.stdout
    num_frames = len(raw_frames) // frame_size
    frames = np.frombuffer(raw_frames, dtype=np.uint8)
    frames = frames[:num_frames * frame_size].reshape(num_frames, 27, 48, 3)
    return frames, fps, duration, sample_fps

def predict_scenes(frames):
"""
Run TransNetV2 on frames to predict shot boundaries.
The model expects batches of 100 frames and outputs per-frame predictions.
We use overlapping windows for better accuracy at boundaries.
"""
    WINDOW_SIZE = 100
    num_frames = len(frames)
    predictions = np.zeros(num_frames, dtype=np.float32)
    # Process in windows of 100 frames with a 50-frame overlap
    for start in range(0, num_frames, WINDOW_SIZE // 2):
        end = min(start + WINDOW_SIZE, num_frames)
        window_frames = frames[start:end]
        # Zero-pad the final window(s) to the full window size
        if len(window_frames) < WINDOW_SIZE:
            padding = np.zeros((WINDOW_SIZE - len(window_frames), 27, 48, 3), dtype=np.uint8)
            window_frames = np.concatenate([window_frames, padding], axis=0)
        # Convert to tensor of shape [1, T, H, W, C]
        frames_tensor = torch.from_numpy(window_frames).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            # Model returns predictions for each frame
            output = model(frames_tensor)
            if isinstance(output, tuple):
                output = output[0]  # Single-frame predictions
            # Output is [B, T, 1]: squeeze and convert logits to probabilities
            preds = torch.sigmoid(output).squeeze().cpu().numpy()
        # Store predictions, keeping the max over overlapping regions
        actual_frames = min(end - start, len(preds))
        for i in range(actual_frames):
            frame_idx = start + i
            if frame_idx < num_frames:
                predictions[frame_idx] = max(predictions[frame_idx], preds[i])
    return predictions

def predictions_to_scenes(predictions, threshold):
"""
Convert per-frame predictions to scene boundaries.
Returns list of (start_frame, end_frame) tuples.
"""
    scenes = []
    scene_start = 0
    for i in range(1, len(predictions)):
        if predictions[i] > threshold:
            # Shot boundary detected
            if i > scene_start:
                scenes.append((scene_start, i))
            scene_start = i
    # Add the final scene
    if scene_start < len(predictions):
        scenes.append((scene_start, len(predictions)))
    return scenes

@app.route('/health', methods=['GET'])
def health():
"""Health check endpoint"""
return jsonify({
'status': 'ok',
'model': 'TransNetV2-PyTorch',
'device': DEVICE,
'threshold': THRESHOLD,
'min_scene_length': MIN_SCENE_LENGTH
})
@app.route('/detect', methods=['POST'])
def detect():
"""
Detect scene boundaries using TransNetV2.
Expects JSON body with:
- video_path: path to video file (must be accessible from container)
- threshold: optional override for prediction threshold (0-1)
- extract_keyframes: if true, return suggested keyframe timestamps
Returns JSON with:
- cuts: array of cut timestamps in seconds
- scenes: array of {start, end, duration, keyframe} objects
- scene_count: number of scenes detected
- duration: video duration in seconds
"""
    data = request.get_json()
    if not data or 'video_path' not in data:
        return jsonify({'error': 'video_path is required'}), 400
    video_path = data['video_path']
    threshold = data.get('threshold', THRESHOLD)
    extract_keyframes = data.get('extract_keyframes', True)
    if not os.path.exists(video_path):
        return jsonify({'error': f'Video file not found: {video_path}'}), 404
    try:
        # Extract frames from the video
        frames, original_fps, duration, sample_fps = extract_frames(video_path)
        frame_count = len(frames)
        if frame_count == 0:
            return jsonify({'error': 'No frames extracted from video'}), 400
        # Run TransNetV2 prediction
        predictions = predict_scenes(frames)
        # Convert predictions to scenes
        scenes_frames = predictions_to_scenes(predictions, threshold)
        # Convert frame indices to timestamps using sample_fps,
        # since that is the rate the frames were extracted at
        cuts = [0.0]
        scenes = []
        predictions_at_cuts = []
        for scene_start, scene_end in scenes_frames:
            start_sec = scene_start / sample_fps
            end_sec = scene_end / sample_fps
            scene_duration = end_sec - start_sec
            # Skip very short scenes (likely noise)
            if scene_duration < MIN_SCENE_LENGTH:
                continue
            # Keyframe at the scene midpoint by default
            keyframe_sec = start_sec + (scene_duration / 2)
            # For long scenes (> 2 min), use the 1/3 point for better action coverage
            if scene_duration > 120:
                keyframe_sec = start_sec + (scene_duration / 3)
            scenes.append({
                'start': round(start_sec, 3),
                'end': round(end_sec, 3),
                'duration': round(scene_duration, 3),
                'keyframe': round(keyframe_sec, 3) if extract_keyframes else None
            })
            # Add the cut timestamp if not already present
            if start_sec > 0 and round(start_sec, 3) not in cuts:
                cuts.append(round(start_sec, 3))
                # Record the prediction score at this boundary frame
                if scene_start < len(predictions):
                    predictions_at_cuts.append(float(predictions[scene_start]))
        # Always include the video end
        cuts.append(round(duration, 3))
        cuts = sorted(set(cuts))
        # Flat list of keyframes for convenience
        keyframes = [s['keyframe'] for s in scenes if s['keyframe'] is not None]
        return jsonify({
            'cuts': cuts,
            'scenes': scenes,
            'keyframes': keyframes,
            'scene_count': len(scenes),
            'duration': round(duration, 3),
            'fps': round(original_fps, 2),
            'sample_fps': round(sample_fps, 2),
            'frame_count': frame_count,
            'threshold': threshold,
            'predictions_at_cuts': predictions_at_cuts
        })
    except Exception as e:
        import traceback
        return jsonify({
            'error': str(e),
            'traceback': traceback.format_exc()
        }), 500

@app.route('/detect_batch', methods=['POST'])
def detect_batch():
"""
Detect scenes for multiple videos in batch.
Expects JSON body with:
- video_paths: array of paths to video files
Returns JSON with:
- results: array of detection results (same format as /detect)
"""
    data = request.get_json()
    if not data or 'video_paths' not in data:
        return jsonify({'error': 'video_paths array is required'}), 400
    results = []
    for video_path in data['video_paths']:
        try:
            single_result = detect_single(video_path, data.get('threshold', THRESHOLD))
            results.append({'video_path': video_path, **single_result})
        except Exception as e:
            results.append({'video_path': video_path, 'error': str(e)})
    return jsonify({'results': results})

def detect_single(video_path, threshold):
"""Helper for batch detection"""
if not os.path.exists(video_path):
raise FileNotFoundError(f'Video file not found: {video_path}')
frames, original_fps, duration, sample_fps = extract_frames(video_path)
if len(frames) == 0:
raise ValueError('No frames extracted from video')
predictions = predict_scenes(frames)
scenes_frames = predictions_to_scenes(predictions, threshold)
cuts = [0.0]
scenes = []
for scene_start, scene_end in scenes_frames:
start_sec = scene_start / sample_fps
end_sec = scene_end / sample_fps
scene_duration = end_sec - start_sec
if scene_duration < MIN_SCENE_LENGTH:
continue
keyframe_sec = start_sec + (scene_duration / 2)
if scene_duration > 120:
keyframe_sec = start_sec + (scene_duration / 3)
scenes.append({
'start': round(start_sec, 3),
'end': round(end_sec, 3),
'duration': round(scene_duration, 3),
'keyframe': round(keyframe_sec, 3)
})
if start_sec > 0:
cuts.append(round(start_sec, 3))
cuts.append(round(duration, 3))
cuts = sorted(list(set(cuts)))
keyframes = [s['keyframe'] for s in scenes]
return {
'cuts': cuts,
'scenes': scenes,
'keyframes': keyframes,
'scene_count': len(scenes),
'duration': round(duration, 3)
}
@app.route('/info', methods=['GET'])
def info():
"""Get service configuration"""
return jsonify({
'service': 'transnetv2',
'model': 'TransNetV2-PyTorch',
'device': DEVICE,
'description': 'Neural shot boundary detection',
'threshold': THRESHOLD,
'min_scene_length': MIN_SCENE_LENGTH,
'features': [
'Hard cut detection',
'Soft transition detection (dissolves, fades)',
'Keyframe extraction',
'Batch processing',
'GPU acceleration'
]
})
if __name__ == '__main__':
    load_model()
    print("TransNetV2 Scene Detection Service starting...")
    print(f"Device: {DEVICE}")
    print(f"Threshold: {THRESHOLD}")
    print(f"Min scene length: {MIN_SCENE_LENGTH}s")
    app.run(host='0.0.0.0', port=5005, debug=False)
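# Quick smoke test once the service is up (illustrative):
#   curl http://localhost:5005/health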