#!/usr/bin/env python3
|
|
"""
|
|
CLIP Embeddings Service
|
|
|
|
A Flask API for extracting CLIP visual embeddings from images.
|
|
Used for visual similarity and diversity scoring in clip selection.
|
|
"""
|
|
|
|
import os
|
|
import io
|
|
import base64
|
|
import numpy as np
|
|
from flask import Flask, request, jsonify
|
|
import torch
|
|
from PIL import Image
|
|
|
|
app = Flask(__name__)

# Configuration
# Model name is passed straight to clip.load() below; default is ViT-B/32.
MODEL_NAME = os.getenv('CLIP_MODEL', 'ViT-B/32')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
GPU_BATCH_SIZE = int(os.getenv('CLIP_GPU_BATCH_SIZE', '32'))  # Increase for more VRAM usage

# Load model on startup
print(f"Loading CLIP model {MODEL_NAME} on {DEVICE}...")
# NOTE(review): `clip` is imported here rather than at the top of the file —
# presumably so the banner above prints before the slow import/model download;
# confirm before moving this import.
import clip
model, preprocess = clip.load(MODEL_NAME, device=DEVICE)
model.eval()  # inference mode: disables dropout / batch-norm updates
print(f"CLIP model loaded successfully")
|
|
|
|
|
|
def decode_image(image_data: str) -> Image.Image:
    """Decode a base64 string (optionally a data URL) into an RGB PIL Image."""
    payload = image_data
    # Data URLs look like "data:image/png;base64,<payload>" — keep only the payload.
    if payload.startswith('data:'):
        payload = payload.split(',', 1)[1]
    raw_bytes = base64.b64decode(payload)
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer).convert('RGB')
|
|
|
|
|
|
def get_embedding(image: Image.Image) -> np.ndarray:
    """Return the unit-normalized CLIP embedding of one image as a flat array."""
    with torch.no_grad():
        batch = preprocess(image).unsqueeze(0).to(DEVICE)
        features = model.encode_image(batch)
        # Scale to unit length so downstream dot products are cosine similarities.
        features = features / features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
|
|
|
|
|
|
def get_embeddings_batch(images: list[Image.Image], batch_size: int = 32) -> list[np.ndarray]:
    """Extract CLIP embeddings from multiple PIL Images using GPU batching.

    Args:
        images: PIL Images to embed (may be empty).
        batch_size: images per forward pass; non-positive values are clamped
            to 1 (previously a 0/negative value from a client request made
            range() raise ValueError deep inside the loop).

    Returns:
        One unit-normalized 1-D np.ndarray per input image, in input order.
    """
    # Guard against a non-positive client-supplied batch size.
    batch_size = max(1, int(batch_size))
    all_embeddings: list[np.ndarray] = []

    with torch.no_grad():
        # Process in GPU batches.
        for start in range(0, len(images), batch_size):
            chunk = images[start:start + batch_size]

            # Stack preprocessed images into one tensor for a single forward pass.
            tensors = torch.stack([preprocess(img) for img in chunk]).to(DEVICE)
            embeddings = model.encode_image(tensors)

            # Unit-normalize so dot products are cosine similarities.
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

            # extend() over the 2-D array yields one row per image,
            # replacing the original per-row append loop.
            all_embeddings.extend(embeddings.cpu().numpy())

    return all_embeddings
|
|
|
|
|
|
@app.route('/health', methods=['GET'])
def health():
    """Health check endpoint reporting model and device status."""
    # Report where the model parameters actually live, not just whether CUDA
    # is available — catches a model silently left on CPU.
    actual_device = next(model.parameters()).device.type

    payload = {
        'status': 'ok',
        'model': MODEL_NAME,
        'device': DEVICE,
        'model_device': actual_device,  # Actual location of model parameters
        'cuda_available': torch.cuda.is_available(),
        'embedding_dim': 512 if 'ViT-B' in MODEL_NAME else 768,
        'gpu_batch_size': GPU_BATCH_SIZE,
    }
    return jsonify(payload)
|
|
|
|
|
|
@app.route('/embed', methods=['POST'])
def embed():
    """
    Extract CLIP embedding for a single image.

    Expects JSON body with:
    - image: base64-encoded image data (with or without data URL prefix)

    Returns JSON with:
    - embedding: array of floats (512 or 768 dimensions)
    - dimensions: length of the embedding
    """
    data = request.get_json()
    if not data or 'image' not in data:
        return jsonify({'error': 'image is required'}), 400

    try:
        image = decode_image(data['image'])
        vector = get_embedding(image)
        return jsonify({
            'embedding': vector.tolist(),
            'dimensions': len(vector)
        })
    except Exception as e:
        import traceback
        return jsonify({
            'error': str(e),
            'traceback': traceback.format_exc()
        }), 500
|
|
|
|
|
|
@app.route('/embed_batch', methods=['POST'])
def embed_batch():
    """
    Extract CLIP embeddings for multiple images in batch.

    Uses GPU batching for efficient inference — multiple images go through
    the model in a single forward pass.

    Expects JSON body with:
    - images: array of base64-encoded image data
    - include_indices: if true, wrap each result as {index, embedding}
    - batch_size: optional GPU batch size override (default: CLIP_GPU_BATCH_SIZE env var)

    Returns JSON with:
    - embeddings: per-input results in original order (None / error entry on failure)
    - dimensions: embedding dimensionality
    - total / successful / failed: bookkeeping about decode failures
    - batch_size: GPU batch size used
    """
    data = request.get_json()
    if not data or 'images' not in data:
        return jsonify({'error': 'images array is required'}), 400

    images_data = data['images']
    include_indices = data.get('include_indices', False)
    batch_size = data.get('batch_size', GPU_BATCH_SIZE)

    try:
        # Decode every payload first, remembering which positions failed.
        decoded = []
        ok_positions = []
        failures = []
        for pos, payload in enumerate(images_data):
            try:
                decoded.append(decode_image(payload))
                ok_positions.append(pos)
            except Exception as exc:
                failures.append({'index': pos, 'error': str(exc)})

        # Embed the successfully decoded images in GPU batches.
        vectors = []
        if decoded:
            vectors = get_embeddings_batch(decoded, batch_size=batch_size)

        # Re-assemble results in the caller's original order; slots for
        # failed decodes stay None unless include_indices adds error detail.
        embeddings = [None] * len(images_data)
        for pos, vec in zip(ok_positions, vectors):
            if include_indices:
                embeddings[pos] = {'index': pos, 'embedding': vec.tolist()}
            else:
                embeddings[pos] = vec.tolist()

        if include_indices:
            for failure in failures:
                embeddings[failure['index']] = {
                    'index': failure['index'],
                    'embedding': None,
                    'error': failure['error'],
                }

        return jsonify({
            'embeddings': embeddings,
            'dimensions': 512 if 'ViT-B' in MODEL_NAME else 768,
            'total': len(images_data),
            'successful': len(ok_positions),
            'failed': failures if failures else None,
            'batch_size': batch_size
        })

    except Exception as e:
        import traceback
        return jsonify({
            'error': str(e),
            'traceback': traceback.format_exc()
        }), 500
|
|
|
|
|
|
@app.route('/similarity', methods=['POST'])
def similarity():
    """
    Calculate cosine similarity between embeddings (dot product of
    unit-normalized vectors).

    Expects JSON body with either:
    - embedding1 + embedding2: one pair to compare
    OR
    - embedding + embeddings: one vector compared against many

    Returns JSON with the score(s) plus max/min/avg for the one-vs-many form.
    """
    data = request.get_json()
    if not data:
        return jsonify({'error': 'Request body required'}), 400

    try:
        if 'embedding1' in data and 'embedding2' in data:
            # Single pair comparison.
            score = float(np.dot(np.array(data['embedding1']),
                                 np.array(data['embedding2'])))
            return jsonify({'similarity': score})

        if 'embedding' in data and 'embeddings' in data:
            # One query vector against a list of candidates.
            query = np.array(data['embedding'])
            scores = [float(np.dot(query, np.array(candidate)))
                      for candidate in data['embeddings']]

            return jsonify({
                'similarities': scores,
                'max_similarity': max(scores) if scores else 0,
                'min_similarity': min(scores) if scores else 0,
                'avg_similarity': sum(scores) / len(scores) if scores else 0
            })

        return jsonify({'error': 'Provide embedding1+embedding2 or embedding+embeddings'}), 400

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
|
@app.route('/diversity', methods=['POST'])
def diversity():
    """
    Score how different a candidate frame is from the frames already selected.
    Higher score = more different from the existing selection.

    Expects JSON body with:
    - frame_embedding: embedding of candidate frame
    - selected_embeddings: array of already-selected frame embeddings

    Returns JSON with:
    - diversity_score: 1 - max similarity (1 = completely different)
    - nearest_similarity: similarity to the most similar selected frame
    - nearest_index: index of that most similar frame
    """
    data = request.get_json()
    if not data or 'frame_embedding' not in data:
        return jsonify({'error': 'frame_embedding is required'}), 400

    try:
        candidate = np.array(data['frame_embedding'])
        selected = data.get('selected_embeddings', [])

        # Nothing selected yet: the candidate is maximally diverse by definition.
        if not selected:
            return jsonify({
                'diversity_score': 1.0,
                'nearest_similarity': 0.0,
                'nearest_index': None
            })

        # Similarity of the candidate to each already-selected frame.
        scores = [float(np.dot(candidate, np.array(emb))) for emb in selected]
        nearest_sim = max(scores)
        nearest_idx = scores.index(nearest_sim)

        # Diversity is the inverse of the closest match.
        return jsonify({
            'diversity_score': round(1.0 - nearest_sim, 4),
            'nearest_similarity': round(nearest_sim, 4),
            'nearest_index': nearest_idx
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
|
@app.route('/cluster', methods=['POST'])
def cluster():
    """
    Cluster embeddings and return diverse representatives.

    Expects JSON body with:
    - embeddings: array of embeddings
    - num_clusters: number of clusters/representatives to select
    - timestamps: optional array of timestamps for each embedding

    Returns JSON with:
    - representatives: indices of diverse representative frames
    - clusters: cluster assignment for each frame
    """
    data = request.get_json()
    if not data or 'embeddings' not in data:
        return jsonify({'error': 'embeddings array is required'}), 400

    try:
        embeddings = np.array(data['embeddings'])
        num_clusters = data.get('num_clusters', 5)
        timestamps = data.get('timestamps', list(range(len(embeddings))))

        # Fewer frames than clusters: every frame is its own representative.
        if len(embeddings) <= num_clusters:
            return jsonify({
                'representatives': list(range(len(embeddings))),
                'clusters': list(range(len(embeddings)))
            })

        # k-means (k-means++ init) partitions the frames into visually
        # coherent groups for diverse selection.
        from sklearn.cluster import KMeans

        kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init='auto')
        assignments = kmeans.fit_predict(embeddings)

        # One representative per cluster: the member nearest its centroid.
        representatives = []
        for label in range(num_clusters):
            member_indices = np.where(assignments == label)[0]
            members = embeddings[assignments == label]

            center = kmeans.cluster_centers_[label]
            offsets = np.linalg.norm(members - center, axis=1)
            representatives.append(int(member_indices[np.argmin(offsets)]))

        # Keep representatives in chronological order.
        if timestamps:
            representatives.sort(key=lambda idx: timestamps[idx])

        return jsonify({
            'representatives': representatives,
            'clusters': assignments.tolist(),
            'num_clusters': num_clusters
        })

    except Exception as e:
        import traceback
        return jsonify({
            'error': str(e),
            'traceback': traceback.format_exc()
        }), 500
|
|
|
|
|
|
@app.route('/info', methods=['GET'])
def info():
    """Describe the service configuration and supported features."""
    return jsonify({
        'service': 'clip-embeddings',
        'model': MODEL_NAME,
        'device': DEVICE,
        # Actual location of model parameters (vs. the configured DEVICE).
        'model_device': next(model.parameters()).device.type,
        'embedding_dim': 512 if 'ViT-B' in MODEL_NAME else 768,
        'gpu_batch_size': GPU_BATCH_SIZE,
        'features': [
            'Single image embedding',
            'GPU-batched embedding extraction',
            'Similarity calculation',
            'Diversity scoring',
            'K-means clustering for diverse selection'
        ]
    })
|
|
|
|
|
|
if __name__ == '__main__':
    # Startup banner so operators can see the active configuration in logs.
    print(f"CLIP Embeddings Service starting...")
    print(f"Model: {MODEL_NAME}")
    print(f"Device: {DEVICE}")
    print(f"GPU Batch Size: {GPU_BATCH_SIZE}")
    # Listen on all interfaces on port 5006; debug disabled for production use.
    app.run(host='0.0.0.0', port=5006, debug=False)
|