#!/usr/bin/env python3
"""
TransNetV2 Scene Detection Service (PyTorch Version)

A Flask API wrapping TransNetV2 for neural shot boundary detection.
Uses the PyTorch implementation with weights from HuggingFace.
Returns cut timestamps and keyframe suggestions for each scene.
"""
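
# Example request against a running instance (path hypothetical; the
# service listens on port 5005, see app.run at the bottom of this file):
#   curl -X POST http://localhost:5005/detect \
#        -H 'Content-Type: application/json' \
#        -d '{"video_path": "/videos/clip.mp4"}'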

import os
import sys
import subprocess

import numpy as np
import torch
from flask import Flask, request, jsonify

# Add TransNetV2 PyTorch to path
sys.path.insert(0, '/app/TransNetV2/inference-pytorch')

app = Flask(__name__)

# Configuration
MIN_SCENE_LENGTH = float(os.getenv('TRANSNET_MIN_SCENE_LENGTH', '1.0'))
THRESHOLD = float(os.getenv('TRANSNET_THRESHOLD', '0.5'))
WEIGHTS_PATH = os.getenv('TRANSNET_WEIGHTS', '/app/weights/transnetv2-pytorch-weights.pth')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
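
# These can be overridden at launch, e.g. (invocation illustrative; the
# module filename is hypothetical):
#   TRANSNET_THRESHOLD=0.35 TRANSNET_MIN_SCENE_LENGTH=0.5 python transnet_service.py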

# Model will be loaded on startup
model = None


def load_model():
    """Load the TransNetV2 PyTorch model and move it to the target device."""
    global model

    from transnetv2_pytorch import TransNetV2

    print(f"Loading TransNetV2 PyTorch model on {DEVICE}...")
    model = TransNetV2()

    # Load weights
    state_dict = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()

    print("TransNetV2 model loaded successfully")


def extract_frames(video_path):
    """
    Extract frames from a video using ffmpeg.
    Returns a numpy array of shape [N, H, W, 3] with uint8 values,
    plus the original fps, the duration, and the sampling fps.
    TransNetV2 expects 48x27 resolution.
    """
    # Get video info first
    probe_cmd = [
        'ffprobe', '-v', 'error',
        '-select_streams', 'v:0',
        '-show_entries', 'stream=width,height,r_frame_rate',
        '-show_entries', 'format=duration',
        '-of', 'csv=p=0',
        video_path
    ]
    result = subprocess.run(probe_cmd, capture_output=True, text=True)
    lines = result.stdout.strip().split('\n')

    # Parse duration (the format section is printed after the stream section)
    duration = float(lines[-1]) if lines else 0.0

    # Parse fps from r_frame_rate (e.g. "30000/1001" or "30/1")
    fps_str = lines[0].split(',')[2] if lines and ',' in lines[0] else "30/1"
    if '/' in fps_str:
        num, den = fps_str.split('/')
        fps = float(num) / float(den)
    else:
        fps = float(fps_str)

    # Extract frames at 48x27 resolution (TransNetV2's input size),
    # capping the sampling rate at 25 fps to save memory
    sample_fps = min(fps, 25)

    ffmpeg_cmd = [
        'ffmpeg', '-i', video_path,
        '-vf', f'fps={sample_fps},scale=48:27',
        '-pix_fmt', 'rgb24',
        '-f', 'rawvideo',
        '-'
    ]

    result = subprocess.run(ffmpeg_cmd, capture_output=True)

    if result.returncode != 0:
        raise RuntimeError(f"FFmpeg failed: {result.stderr.decode()}")

    # Parse raw RGB24 frames: 48 x 27 pixels, 3 bytes per pixel
    frame_size = 48 * 27 * 3
    raw_frames = result.stdout
    num_frames = len(raw_frames) // frame_size

    frames = np.frombuffer(raw_frames, dtype=np.uint8)
    frames = frames[:num_frames * frame_size].reshape(num_frames, 27, 48, 3)

    return frames, fps, duration, sample_fps
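
# For illustration (values hypothetical): on a 1080p NTSC video the probe
# above prints CSV like
#   1920,1080,30000/1001
#   600.016000
# which parses to fps ~= 29.97 and duration = 600.016 seconds.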


def predict_scenes(frames):
    """
    Run TransNetV2 on frames to predict shot boundaries.

    The model expects batches of 100 frames and outputs per-frame predictions.
    We use overlapping windows for better accuracy at boundaries.
    """
    WINDOW_SIZE = 100

    num_frames = len(frames)
    predictions = np.zeros(num_frames, dtype=np.float32)

    # Process in windows of 100 frames with a 50-frame overlap
    for start in range(0, num_frames, WINDOW_SIZE // 2):
        end = min(start + WINDOW_SIZE, num_frames)
        window_frames = frames[start:end]

        # Pad the last window up to WINDOW_SIZE if necessary
        if len(window_frames) < WINDOW_SIZE:
            padding = np.zeros((WINDOW_SIZE - len(window_frames), 27, 48, 3), dtype=np.uint8)
            window_frames = np.concatenate([window_frames, padding], axis=0)

        # Convert to tensor of shape [1, T, H, W, C]
        frames_tensor = torch.from_numpy(window_frames).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            # Model returns predictions for each frame
            output = model(frames_tensor)
            if isinstance(output, tuple):
                output = output[0]  # single-frame predictions

            # Output is [B, T, 1]: squeeze and convert logits to probabilities
            preds = torch.sigmoid(output).squeeze().cpu().numpy()

        # Store predictions (use max for overlapping regions)
        actual_frames = min(end - start, len(preds))
        for i in range(actual_frames):
            frame_idx = start + i
            if frame_idx < num_frames:
                predictions[frame_idx] = max(predictions[frame_idx], preds[i])

    return predictions
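
# Illustrative window layout (assuming num_frames = 220): windows start at
# frames 0, 50, 100, 150, 200, each covering up to 100 frames, so every
# frame after the first 50 is scored by two windows and the higher score wins.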


def predictions_to_scenes(predictions, threshold):
    """
    Convert per-frame predictions to scene boundaries.
    Returns a list of (start_frame, end_frame) tuples.
    """
    scenes = []
    scene_start = 0

    for i in range(1, len(predictions)):
        if predictions[i] > threshold:
            # Shot boundary detected: close the current scene here
            if i > scene_start:
                scenes.append((scene_start, i))
            scene_start = i

    # Add the final scene
    if scene_start < len(predictions):
        scenes.append((scene_start, len(predictions)))

    return scenes
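
# Worked example: predictions = [0.0, 0.9, 0.0, 0.0] with threshold = 0.5
# yields [(0, 1), (1, 4)] -- the spike at frame 1 ends the first scene and
# starts the second, which runs to the end of the array.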


@app.route('/health', methods=['GET'])
def health():
    """Health check endpoint"""
    return jsonify({
        'status': 'ok',
        'model': 'TransNetV2-PyTorch',
        'device': DEVICE,
        'threshold': THRESHOLD,
        'min_scene_length': MIN_SCENE_LENGTH
    })
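
# Example response with the default configuration on a CPU-only host:
#   {"status": "ok", "model": "TransNetV2-PyTorch", "device": "cpu",
#    "threshold": 0.5, "min_scene_length": 1.0}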


@app.route('/detect', methods=['POST'])
def detect():
    """
    Detect scene boundaries using TransNetV2.

    Expects a JSON body with:
    - video_path: path to the video file (must be accessible from the container)
    - threshold: optional override for the prediction threshold (0-1)
    - extract_keyframes: if true, return suggested keyframe timestamps

    Returns JSON with:
    - cuts: array of cut timestamps in seconds
    - scenes: array of {start, end, duration, keyframe} objects
    - scene_count: number of scenes detected
    - duration: video duration in seconds
    """
    data = request.get_json()

    if not data or 'video_path' not in data:
        return jsonify({'error': 'video_path is required'}), 400

    video_path = data['video_path']
    threshold = data.get('threshold', THRESHOLD)
    extract_keyframes = data.get('extract_keyframes', True)

    if not os.path.exists(video_path):
        return jsonify({'error': f'Video file not found: {video_path}'}), 404

    try:
        # Extract frames from the video
        frames, original_fps, duration, sample_fps = extract_frames(video_path)
        frame_count = len(frames)

        if frame_count == 0:
            return jsonify({'error': 'No frames extracted from video'}), 400

        # Run TransNetV2 prediction
        predictions = predict_scenes(frames)

        # Convert predictions to scenes
        scenes_frames = predictions_to_scenes(predictions, threshold)

        # Convert frame indices to timestamps, using sample_fps since
        # that is the rate the frames were extracted at
        cuts = [0.0]
        scenes = []
        predictions_at_cuts = []

        for scene_start, scene_end in scenes_frames:
            start_sec = scene_start / sample_fps
            end_sec = scene_end / sample_fps
            scene_duration = end_sec - start_sec

            # Skip very short scenes (likely noise)
            if scene_duration < MIN_SCENE_LENGTH:
                continue

            # Keyframe defaults to the scene midpoint
            keyframe_sec = start_sec + (scene_duration / 2)

            # For long scenes (> 2 min), use the 1/3 point for better action coverage
            if scene_duration > 120:
                keyframe_sec = start_sec + (scene_duration / 3)

            scenes.append({
                'start': round(start_sec, 3),
                'end': round(end_sec, 3),
                'duration': round(scene_duration, 3),
                'keyframe': round(keyframe_sec, 3) if extract_keyframes else None
            })

            # Add the cut timestamp if not already present
            if start_sec > 0 and start_sec not in cuts:
                cuts.append(round(start_sec, 3))
                # Record the prediction score at this cut
                if scene_start < len(predictions):
                    predictions_at_cuts.append(float(predictions[scene_start]))

        # Always include the video end
        if duration not in cuts:
            cuts.append(round(duration, 3))

        cuts = sorted(set(cuts))

        # Extract the keyframe list for convenience
        keyframes = [s['keyframe'] for s in scenes if s['keyframe'] is not None]

        return jsonify({
            'cuts': cuts,
            'scenes': scenes,
            'keyframes': keyframes,
            'scene_count': len(scenes),
            'duration': round(duration, 3),
            'fps': round(original_fps, 2),
            'sample_fps': round(sample_fps, 2),
            'frame_count': frame_count,
            'threshold': threshold,
            'predictions_at_cuts': predictions_at_cuts
        })

    except Exception as e:
        import traceback
        return jsonify({
            'error': str(e),
            'traceback': traceback.format_exc()
        }), 500
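
# A minimal client sketch (path hypothetical; assumes the `requests`
# package is installed and the service is reachable on localhost:5005):
#
#   import requests
#   resp = requests.post('http://localhost:5005/detect',
#                        json={'video_path': '/videos/clip.mp4', 'threshold': 0.4})
#   for scene in resp.json()['scenes']:
#       print(scene['start'], scene['end'], scene['keyframe'])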


@app.route('/detect_batch', methods=['POST'])
def detect_batch():
    """
    Detect scenes for multiple videos in batch.

    Expects a JSON body with:
    - video_paths: array of paths to video files

    Returns JSON with:
    - results: array of detection results (same format as /detect)
    """
    data = request.get_json()

    if not data or 'video_paths' not in data:
        return jsonify({'error': 'video_paths array is required'}), 400

    results = []
    for video_path in data['video_paths']:
        try:
            single_result = detect_single(video_path, data.get('threshold', THRESHOLD))
            results.append({'video_path': video_path, **single_result})
        except Exception as e:
            results.append({'video_path': video_path, 'error': str(e)})

    return jsonify({'results': results})
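
# Example body (paths hypothetical):
#   {"video_paths": ["/videos/a.mp4", "/videos/b.mp4"], "threshold": 0.5}
# Per-video failures are reported inline as {"video_path": ..., "error": ...}
# rather than aborting the whole batch.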


def detect_single(video_path, threshold):
    """Helper for batch detection"""
    if not os.path.exists(video_path):
        raise FileNotFoundError(f'Video file not found: {video_path}')

    frames, original_fps, duration, sample_fps = extract_frames(video_path)

    if len(frames) == 0:
        raise ValueError('No frames extracted from video')

    predictions = predict_scenes(frames)
    scenes_frames = predictions_to_scenes(predictions, threshold)

    cuts = [0.0]
    scenes = []

    for scene_start, scene_end in scenes_frames:
        start_sec = scene_start / sample_fps
        end_sec = scene_end / sample_fps
        scene_duration = end_sec - start_sec

        if scene_duration < MIN_SCENE_LENGTH:
            continue

        keyframe_sec = start_sec + (scene_duration / 2)
        if scene_duration > 120:
            keyframe_sec = start_sec + (scene_duration / 3)

        scenes.append({
            'start': round(start_sec, 3),
            'end': round(end_sec, 3),
            'duration': round(scene_duration, 3),
            'keyframe': round(keyframe_sec, 3)
        })

        if start_sec > 0:
            cuts.append(round(start_sec, 3))

    cuts.append(round(duration, 3))
    cuts = sorted(set(cuts))
    keyframes = [s['keyframe'] for s in scenes]

    return {
        'cuts': cuts,
        'scenes': scenes,
        'keyframes': keyframes,
        'scene_count': len(scenes),
        'duration': round(duration, 3)
    }


@app.route('/info', methods=['GET'])
def info():
    """Get service configuration"""
    return jsonify({
        'service': 'transnetv2',
        'model': 'TransNetV2-PyTorch',
        'device': DEVICE,
        'description': 'Neural shot boundary detection',
        'threshold': THRESHOLD,
        'min_scene_length': MIN_SCENE_LENGTH,
        'features': [
            'Hard cut detection',
            'Soft transition detection (dissolves, fades)',
            'Keyframe extraction',
            'Batch processing',
            'GPU acceleration'
        ]
    })


if __name__ == '__main__':
    load_model()
    print("TransNetV2 Scene Detection Service starting...")
    print(f"Device: {DEVICE}")
    print(f"Threshold: {THRESHOLD}")
    print(f"Min scene length: {MIN_SCENE_LENGTH}s")
    app.run(host='0.0.0.0', port=5005, debug=False)
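
# Caveat: load_model() only runs when this file is executed directly;
# serving the app through a WSGI server instead of app.run would require
# loading the model at module import time.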