
#!/usr/bin/env python3
"""
CLIP Embeddings Service
A Flask API for extracting CLIP visual embeddings from images.
Used for visual similarity and diversity scoring during video clip selection.
"""
import os
import io
import base64
import numpy as np
from flask import Flask, request, jsonify
import torch
from PIL import Image
app = Flask(__name__)
# Configuration
MODEL_NAME = os.getenv('CLIP_MODEL', 'ViT-B/32')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
GPU_BATCH_SIZE = int(os.getenv('CLIP_GPU_BATCH_SIZE', '32')) # Increase if more VRAM is available
# Load model on startup
print(f"Loading CLIP model {MODEL_NAME} on {DEVICE}...")
import clip
model, preprocess = clip.load(MODEL_NAME, device=DEVICE)
model.eval()
# Read the embedding dimensionality from the model instead of guessing from its
# name (e.g. RN50 is 1024-d, ViT-B/32 is 512-d, ViT-L/14 is 768-d)
EMBED_DIM = model.visual.output_dim
print(f"CLIP model loaded successfully ({EMBED_DIM}-d embeddings)")
def decode_image(image_data: str) -> Image.Image:
"""Decode base64 image data to PIL Image"""
if image_data.startswith('data:'):
# Remove data URL prefix
image_data = image_data.split(',', 1)[1]
image_bytes = base64.b64decode(image_data)
return Image.open(io.BytesIO(image_bytes)).convert('RGB')
def get_embedding(image: Image.Image) -> np.ndarray:
"""Extract CLIP embedding from PIL Image"""
with torch.no_grad():
image_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
embedding = model.encode_image(image_tensor)
# L2-normalize so dot products between embeddings equal cosine similarity
embedding = embedding / embedding.norm(dim=-1, keepdim=True)
return embedding.cpu().numpy().flatten()
def get_embeddings_batch(images: list[Image.Image], batch_size: int = GPU_BATCH_SIZE) -> list[np.ndarray]:
"""Extract CLIP embeddings from multiple PIL Images using GPU batching"""
all_embeddings = []
with torch.no_grad():
# Process in GPU batches
for i in range(0, len(images), batch_size):
batch_images = images[i:i + batch_size]
# Stack preprocessed images into a batch tensor
batch_tensors = torch.stack([preprocess(img) for img in batch_images]).to(DEVICE)
# Run batch through model (single forward pass)
batch_embeddings = model.encode_image(batch_tensors)
# Normalize embeddings
batch_embeddings = batch_embeddings / batch_embeddings.norm(dim=-1, keepdim=True)
# Move to CPU and collect one numpy row per image
all_embeddings.extend(batch_embeddings.cpu().numpy())
return all_embeddings
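# Note: batching changes throughput, not results - each row of the batched
# output matches what get_embedding() returns for the same image, up to small
# floating-point differences from batched GPU kernels.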
@app.route('/health', methods=['GET'])
def health():
"""Health check endpoint"""
# Verify model is actually on GPU (not just reporting CUDA available)
model_device = next(model.parameters()).device.type
return jsonify({
'status': 'ok',
'model': MODEL_NAME,
'device': DEVICE,
'model_device': model_device, # Actual location of model parameters
'cuda_available': torch.cuda.is_available(),
'embedding_dim': EMBED_DIM,
'gpu_batch_size': GPU_BATCH_SIZE
})
@app.route('/embed', methods=['POST'])
def embed():
"""
Extract CLIP embedding for a single image.
Expects JSON body with:
- image: base64-encoded image data (with or without data URL prefix)
Returns JSON with:
- embedding: L2-normalized embedding as an array of floats
- dimensions: embedding dimensionality
"""
data = request.get_json()
if not data or 'image' not in data:
return jsonify({'error': 'image is required'}), 400
try:
image = decode_image(data['image'])
embedding = get_embedding(image)
return jsonify({
'embedding': embedding.tolist(),
'dimensions': len(embedding)
})
except Exception as e:
import traceback
return jsonify({
'error': str(e),
'traceback': traceback.format_exc()
}), 500
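# Example /embed response shape (values illustrative; dimensionality depends
# on the loaded model):
#   {"embedding": [0.0123, -0.0456, ...], "dimensions": 512}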
@app.route('/embed_batch', methods=['POST'])
def embed_batch():
"""
Extract CLIP embeddings for multiple images in batch.
Uses GPU batching for efficient inference - processes multiple images
in a single forward pass through the model.
Expects JSON body with:
- images: array of base64-encoded image data
- include_indices: if true, include frame indices in response
- batch_size: optional GPU batch size override (default: CLIP_GPU_BATCH_SIZE env var)
Returns JSON with:
- embeddings: array of embedding arrays (null for images that failed to decode)
- dimensions: embedding dimensionality
- total / successful: counts of submitted and successfully embedded images
- failed: list of {index, error} entries, or null if all images decoded
- batch_size: GPU batch size used
"""
data = request.get_json()
if not data or 'images' not in data:
return jsonify({'error': 'images array is required'}), 400
images_data = data['images']
include_indices = data.get('include_indices', False)
batch_size = data.get('batch_size', GPU_BATCH_SIZE)
try:
# First pass: decode all images, track failures
decoded_images = []
failed_indices = []
valid_indices = []
for i, img_data in enumerate(images_data):
try:
image = decode_image(img_data)
decoded_images.append(image)
valid_indices.append(i)
except Exception as e:
failed_indices.append({'index': i, 'error': str(e)})
# Process valid images in GPU batches
batch_embeddings = []
if decoded_images:
batch_embeddings = get_embeddings_batch(decoded_images, batch_size=batch_size)
# Build output array with embeddings in original order
embeddings = [None] * len(images_data)
for idx, emb in zip(valid_indices, batch_embeddings):
if include_indices:
embeddings[idx] = {
'index': idx,
'embedding': emb.tolist()
}
else:
embeddings[idx] = emb.tolist()
# Fill in failed entries
for fail in failed_indices:
idx = fail['index']
if include_indices:
embeddings[idx] = {'index': idx, 'embedding': None, 'error': fail['error']}
return jsonify({
'embeddings': embeddings,
'dimensions': EMBED_DIM,
'total': len(images_data),
'successful': len(valid_indices),
'failed': failed_indices if failed_indices else None,
'batch_size': batch_size
})
except Exception as e:
import traceback
return jsonify({
'error': str(e),
'traceback': traceback.format_exc()
}), 500
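# Illustrative client-side call for /embed_batch (not part of the service;
# assumes the `requests` package and a service reachable on localhost:5006):
#
#   import base64, requests
#   with open('frame.jpg', 'rb') as f:
#       b64 = base64.b64encode(f.read()).decode('ascii')
#   resp = requests.post('http://localhost:5006/embed_batch',
#                        json={'images': [b64], 'include_indices': True})
#   embeddings = resp.json()['embeddings']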
@app.route('/similarity', methods=['POST'])
def similarity():
"""
Calculate cosine similarity between embeddings.
Expects JSON body with:
- embedding1: first embedding array
- embedding2: second embedding array
OR
- embedding: single embedding to compare
- embeddings: array of embeddings to compare against
Returns JSON with:
- similarity: cosine similarity score(s) in [-1, 1]. Embeddings are assumed
  to be L2-normalized (as returned by /embed and /embed_batch), so a plain
  dot product equals cosine similarity.
"""
data = request.get_json()
if not data:
return jsonify({'error': 'Request body required'}), 400
try:
if 'embedding1' in data and 'embedding2' in data:
# Single pair comparison
e1 = np.array(data['embedding1'])
e2 = np.array(data['embedding2'])
sim = float(np.dot(e1, e2))
return jsonify({'similarity': sim})
elif 'embedding' in data and 'embeddings' in data:
# Compare one against many
e1 = np.array(data['embedding'])
similarities = []
for e2_data in data['embeddings']:
e2 = np.array(e2_data)
sim = float(np.dot(e1, e2))
similarities.append(sim)
return jsonify({
'similarities': similarities,
'max_similarity': max(similarities) if similarities else 0,
'min_similarity': min(similarities) if similarities else 0,
'avg_similarity': sum(similarities) / len(similarities) if similarities else 0
})
else:
return jsonify({'error': 'Provide embedding1+embedding2 or embedding+embeddings'}), 400
except Exception as e:
return jsonify({'error': str(e)}), 500
@app.route('/diversity', methods=['POST'])
def diversity():
"""
Calculate diversity score for a frame relative to selected frames.
Higher score = more different from existing selection.
Expects JSON body with:
- frame_embedding: embedding of candidate frame
- selected_embeddings: array of already-selected frame embeddings
Returns JSON with:
- diversity_score: 1 minus the highest similarity to any selected frame
  (1 = completely different, 0 = identical; assumes normalized embeddings)
- nearest_similarity: similarity to most similar selected frame
- nearest_index: index of most similar selected frame
"""
data = request.get_json()
if not data or 'frame_embedding' not in data:
return jsonify({'error': 'frame_embedding is required'}), 400
try:
frame_emb = np.array(data['frame_embedding'])
selected_embs = data.get('selected_embeddings', [])
if not selected_embs:
return jsonify({
'diversity_score': 1.0,
'nearest_similarity': 0.0,
'nearest_index': None
})
# Calculate similarity to each selected frame
similarities = []
for emb_data in selected_embs:
emb = np.array(emb_data)
sim = float(np.dot(frame_emb, emb))
similarities.append(sim)
max_sim = max(similarities)
nearest_idx = similarities.index(max_sim)
# Diversity = inverse of max similarity
diversity = 1.0 - max_sim
return jsonify({
'diversity_score': round(diversity, 4),
'nearest_similarity': round(max_sim, 4),
'nearest_index': nearest_idx
})
except Exception as e:
return jsonify({'error': str(e)}), 500
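# Sketch of how a caller might use /diversity for greedy diverse selection
# (hypothetical client code; assumes `requests` and embeddings obtained via
# /embed_batch; the threshold value is illustrative):
#
#   selected = []
#   for emb in candidate_embeddings:
#       resp = requests.post('http://localhost:5006/diversity',
#                            json={'frame_embedding': emb,
#                                  'selected_embeddings': selected})
#       if resp.json()['diversity_score'] > 0.3:
#           selected.append(emb)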
@app.route('/cluster', methods=['POST'])
def cluster():
"""
Cluster embeddings and return diverse representatives.
Expects JSON body with:
- embeddings: array of embeddings
- num_clusters: number of clusters/representatives to select
- timestamps: optional array of timestamps for each embedding
Returns JSON with:
- representatives: indices of diverse representative frames
- clusters: cluster assignment for each frame
"""
data = request.get_json()
if not data or 'embeddings' not in data:
return jsonify({'error': 'embeddings array is required'}), 400
try:
embeddings = np.array(data['embeddings'])
num_clusters = data.get('num_clusters', 5)
timestamps = data.get('timestamps', list(range(len(embeddings))))
if len(embeddings) <= num_clusters:
return jsonify({
'representatives': list(range(len(embeddings))),
'clusters': list(range(len(embeddings)))
})
# Use k-means++ initialization for diverse selection
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init='auto')
clusters = kmeans.fit_predict(embeddings)
# Find frame closest to each cluster center
representatives = []
for i in range(num_clusters):
cluster_mask = clusters == i
cluster_indices = np.where(cluster_mask)[0]
cluster_embeddings = embeddings[cluster_mask]
# Find closest to center
center = kmeans.cluster_centers_[i]
distances = np.linalg.norm(cluster_embeddings - center, axis=1)
closest_idx = cluster_indices[np.argmin(distances)]
representatives.append(int(closest_idx))
# Sort by timestamp to keep representatives in chronological order
# (timestamps defaults to frame order, so it is always non-empty here)
representatives.sort(key=lambda idx: timestamps[idx])
return jsonify({
'representatives': representatives,
'clusters': clusters.tolist(),
'num_clusters': num_clusters
})
except Exception as e:
import traceback
return jsonify({
'error': str(e),
'traceback': traceback.format_exc()
}), 500
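# Design note: the representative for each cluster is the frame nearest its
# centroid (a medoid-style pick), because k-means centroids are averages and
# do not correspond to any actual frame index.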
@app.route('/info', methods=['GET'])
def info():
"""Get service configuration"""
model_device = next(model.parameters()).device.type
return jsonify({
'service': 'clip-embeddings',
'model': MODEL_NAME,
'device': DEVICE,
'model_device': model_device,
'embedding_dim': EMBED_DIM,
'gpu_batch_size': GPU_BATCH_SIZE,
'features': [
'Single image embedding',
'GPU-batched embedding extraction',
'Similarity calculation',
'Diversity scoring',
'K-means clustering for diverse selection'
]
})
if __name__ == '__main__':
print(f"CLIP Embeddings Service starting...")
print(f"Model: {MODEL_NAME}")
print(f"Device: {DEVICE}")
print(f"GPU Batch Size: {GPU_BATCH_SIZE}")
app.run(host='0.0.0.0', port=5006, debug=False)