#!/usr/bin/env python3
"""
CLIP Embeddings Service

A Flask API for extracting CLIP visual embeddings from images.
Used for visual similarity and diversity scoring in clip selection.
"""

import base64
import io
import os
import traceback

import clip
import numpy as np
import torch
from flask import Flask, jsonify, request
from PIL import Image

app = Flask(__name__)

# Configuration
MODEL_NAME = os.getenv('CLIP_MODEL', 'ViT-B/32')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Default GPU batch size for /embed_batch; increase for more VRAM usage.
GPU_BATCH_SIZE = int(os.getenv('CLIP_GPU_BATCH_SIZE', '32'))

# Load model on startup so requests never pay the load cost.
print(f"Loading CLIP model {MODEL_NAME} on {DEVICE}...")
model, preprocess = clip.load(MODEL_NAME, device=DEVICE)
model.eval()
print("CLIP model loaded successfully")


def _detect_embedding_dim() -> int:
    """Return the size of the image embeddings produced by the loaded model.

    Prefers the model's own ``visual.output_dim``; if that attribute is not
    available, falls back to a name-based guess (512 for ViT-B, else 768).
    """
    dim = getattr(getattr(model, 'visual', None), 'output_dim', None)
    if isinstance(dim, int) and dim > 0:
        return dim
    return 512 if 'ViT-B' in MODEL_NAME else 768


# Real embedding width of the loaded model. The name-based guess alone is
# wrong for some model families (e.g. the ResNet variants).
EMBEDDING_DIM = _detect_embedding_dim()


def _json_body():
    """Return the request's JSON body if it is a JSON object, otherwise None.

    ``silent=True`` turns malformed JSON or a missing JSON content type into
    None, so each endpoint can answer with its own 400 message instead of
    Flask's generic error response.
    """
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else None


def _positive_int(value, default: int):
    """Parse ``value`` as a positive int.

    Returns ``default`` when ``value`` is None, and None when the value cannot
    be parsed or is < 1 (callers treat None as "invalid input").
    """
    if value is None:
        return default
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        return None
    return parsed if parsed >= 1 else None


def _error_with_traceback(exc: Exception):
    """Build the 500 response (message + traceback) for an exception being handled.

    Must be called from inside an ``except`` block so ``format_exc`` sees it.
    """
    return jsonify({
        'error': str(exc),
        'traceback': traceback.format_exc()
    }), 500


def _unit_rows(values) -> np.ndarray:
    """Convert ``values`` to float64 and L2-normalize along the last axis.

    Works for a single vector (shape ``(d,)``) or a matrix of row vectors
    (shape ``(n, d)``). All-zero vectors stay zero instead of becoming NaN.
    """
    arr = np.asarray(values, dtype=np.float64)
    norms = np.linalg.norm(arr, axis=-1, keepdims=True)
    return np.divide(arr, norms, out=np.zeros_like(arr), where=norms > 0)


def _cosine_similarities(query, candidates) -> np.ndarray:
    """Cosine similarity between ``query`` and ``candidates``.

    ``query`` is one vector. ``candidates`` is either one vector (result is a
    0-d array) or a list of vectors (result has one entry per candidate).
    Values are clipped to [-1, 1] to absorb floating-point rounding.
    """
    return np.clip(_unit_rows(candidates) @ _unit_rows(query), -1.0, 1.0)


def decode_image(image_data: str) -> Image.Image:
    """Decode base64 image data (optionally a ``data:`` URL) into an RGB PIL image.

    Raises whatever ``base64``/PIL raise on malformed input; callers handle it.
    """
    if image_data.startswith('data:'):
        # Remove the "data:<mime>;base64," prefix; only the payload is base64.
        image_data = image_data.split(',', 1)[1]
    image_bytes = base64.b64decode(image_data)
    return Image.open(io.BytesIO(image_bytes)).convert('RGB')


def get_embedding(image: Image.Image) -> np.ndarray:
    """Extract the L2-normalized CLIP embedding of one PIL image as a 1-D array."""
    return get_embeddings_batch([image], batch_size=1)[0]


def get_embeddings_batch(images: list[Image.Image], batch_size: int = 32) -> list[np.ndarray]:
    """Extract L2-normalized CLIP embeddings for many PIL images using GPU batching.

    Args:
        images: images to embed; output order matches input order.
        batch_size: number of images per forward pass (must be >= 1).

    Returns:
        One 1-D float32 numpy array per input image.

    Raises:
        ValueError: if ``batch_size`` < 1 (otherwise ``range`` would silently
            produce no batches and every embedding would be dropped).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    all_embeddings = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            chunk = images[start:start + batch_size]
            # Stack preprocessed images into one tensor -> single forward pass.
            batch_tensor = torch.stack([preprocess(img) for img in chunk]).to(DEVICE)
            # The model may run in half precision (e.g. on CUDA); normalize in
            # float32 so the returned vectors are accurately unit length.
            batch_embeddings = model.encode_image(batch_tensor).float()
            batch_embeddings = batch_embeddings / batch_embeddings.norm(dim=-1, keepdim=True)
            # Iterating a 2-D array yields its rows.
            all_embeddings.extend(batch_embeddings.cpu().numpy())
    return all_embeddings


@app.route('/health', methods=['GET'])
def health():
    """Health check endpoint."""
    # Report where the parameters actually live, not just whether CUDA exists.
    model_device = next(model.parameters()).device.type
    return jsonify({
        'status': 'ok',
        'model': MODEL_NAME,
        'device': DEVICE,
        'model_device': model_device,  # Actual location of model parameters
        'cuda_available': torch.cuda.is_available(),
        'embedding_dim': EMBEDDING_DIM,
        'gpu_batch_size': GPU_BATCH_SIZE
    })


@app.route('/embed', methods=['POST'])
def embed():
    """
    Extract CLIP embedding for a single image.

    Expects JSON body with:
    - image: base64-encoded image data (with or without data URL prefix)

    Returns JSON with:
    - embedding: array of floats (L2-normalized, length = model embedding size)
    - dimensions: length of the embedding
    """
    data = _json_body()
    if not data or 'image' not in data:
        return jsonify({'error': 'image is required'}), 400

    try:
        image = decode_image(data['image'])
        embedding = get_embedding(image)
        return jsonify({
            'embedding': embedding.tolist(),
            'dimensions': len(embedding)
        })
    except Exception as e:
        return _error_with_traceback(e)


@app.route('/embed_batch', methods=['POST'])
def embed_batch():
    """
    Extract CLIP embeddings for multiple images in batch.

    Uses GPU batching for efficient inference - processes multiple images
    in a single forward pass through the model.

    Expects JSON body with:
    - images: array of base64-encoded image data
    - include_indices: if true, include frame indices in response
    - batch_size: optional positive GPU batch size override
      (default: CLIP_GPU_BATCH_SIZE env var)

    Returns JSON with:
    - embeddings: array in input order; entries for images that failed to
      decode are null (or an object with 'error' when include_indices is true)
    - dimensions: embedding dimensionality
    - total / successful: input count and number embedded
    - failed: list of {index, error}, or null when nothing failed
    - batch_size: GPU batch size used
    """
    data = _json_body()
    if not data or 'images' not in data or not isinstance(data['images'], list):
        return jsonify({'error': 'images array is required'}), 400

    images_data = data['images']
    include_indices = data.get('include_indices', False)
    batch_size = _positive_int(data.get('batch_size'), GPU_BATCH_SIZE)
    if batch_size is None:
        return jsonify({'error': 'batch_size must be a positive integer'}), 400

    try:
        # First pass: decode all images, tracking failures by original index.
        decoded_images = []
        valid_indices = []
        failed_indices = []
        for i, img_data in enumerate(images_data):
            try:
                decoded_images.append(decode_image(img_data))
                valid_indices.append(i)
            except Exception as e:
                failed_indices.append({'index': i, 'error': str(e)})

        # Embed only the images that decoded successfully.
        batch_embeddings = (
            get_embeddings_batch(decoded_images, batch_size=batch_size)
            if decoded_images else []
        )

        # Rebuild output in original order; failed slots stay None by default.
        embeddings = [None] * len(images_data)
        for idx, emb in zip(valid_indices, batch_embeddings):
            values = emb.tolist()
            embeddings[idx] = {'index': idx, 'embedding': values} if include_indices else values

        if include_indices:
            for fail in failed_indices:
                idx = fail['index']
                embeddings[idx] = {'index': idx, 'embedding': None, 'error': fail['error']}

        return jsonify({
            'embeddings': embeddings,
            'dimensions': EMBEDDING_DIM,
            'total': len(images_data),
            'successful': len(valid_indices),
            'failed': failed_indices if failed_indices else None,
            'batch_size': batch_size
        })
    except Exception as e:
        return _error_with_traceback(e)


@app.route('/similarity', methods=['POST'])
def similarity():
    """
    Calculate cosine similarity between embeddings.

    Inputs are L2-normalized before comparison, so unnormalized vectors are
    handled correctly.

    Expects JSON body with:
    - embedding1: first embedding array
    - embedding2: second embedding array
    OR
    - embedding: single embedding to compare
    - embeddings: array of embeddings to compare against

    Returns JSON with:
    - similarity: cosine similarity score (-1 to 1), for the pair form
    - similarities plus max/min/avg_similarity, for the one-vs-many form
    """
    data = _json_body()
    if not data:
        return jsonify({'error': 'Request body required'}), 400

    try:
        if 'embedding1' in data and 'embedding2' in data:
            # Single pair comparison
            sim = float(_cosine_similarities(data['embedding1'], data['embedding2']))
            return jsonify({'similarity': sim})

        if 'embedding' in data and 'embeddings' in data:
            # Compare one against many (vectorized)
            candidates = data['embeddings']
            similarities = (
                _cosine_similarities(data['embedding'], candidates).tolist()
                if candidates else []
            )
            return jsonify({
                'similarities': similarities,
                'max_similarity': max(similarities) if similarities else 0,
                'min_similarity': min(similarities) if similarities else 0,
                'avg_similarity': sum(similarities) / len(similarities) if similarities else 0
            })

        return jsonify({'error': 'Provide embedding1+embedding2 or embedding+embeddings'}), 400
    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/diversity', methods=['POST'])
def diversity():
    """
    Calculate diversity score for a frame relative to selected frames.

    Higher score = more different from existing selection.

    Expects JSON body with:
    - frame_embedding: embedding of candidate frame
    - selected_embeddings: array of already-selected frame embeddings

    Returns JSON with:
    - diversity_score: 1 - nearest cosine similarity (0 = identical; 1 = orthogonal;
      can exceed 1 only if the nearest similarity is negative)
    - nearest_similarity: cosine similarity to most similar selected frame
    - nearest_index: index of most similar selected frame (first on ties)
    """
    data = _json_body()
    if not data or 'frame_embedding' not in data:
        return jsonify({'error': 'frame_embedding is required'}), 400

    try:
        selected_embs = data.get('selected_embeddings', [])
        if not selected_embs:
            # Nothing selected yet: any frame is maximally diverse.
            return jsonify({
                'diversity_score': 1.0,
                'nearest_similarity': 0.0,
                'nearest_index': None
            })

        similarities = _cosine_similarities(data['frame_embedding'], selected_embs)
        nearest_idx = int(np.argmax(similarities))  # first occurrence on ties
        max_sim = float(similarities[nearest_idx])

        # Diversity = inverse of max similarity
        diversity_score = 1.0 - max_sim

        return jsonify({
            'diversity_score': round(diversity_score, 4),
            'nearest_similarity': round(max_sim, 4),
            'nearest_index': nearest_idx
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/cluster', methods=['POST'])
def cluster():
    """
    Cluster embeddings and return diverse representatives.

    Expects JSON body with:
    - embeddings: array of embeddings
    - num_clusters: positive number of clusters/representatives to select (default 5)
    - timestamps: optional array with one timestamp per embedding
      (defaults to the embedding indices)

    Returns JSON with:
    - representatives: indices of diverse representative frames, in
      chronological (timestamp) order
    - clusters: cluster assignment for each frame
    - num_clusters: number of clusters used
    """
    data = _json_body()
    if not data or 'embeddings' not in data:
        return jsonify({'error': 'embeddings array is required'}), 400

    num_clusters = _positive_int(data.get('num_clusters'), 5)
    if num_clusters is None:
        return jsonify({'error': 'num_clusters must be a positive integer'}), 400

    try:
        embeddings = np.asarray(data['embeddings'], dtype=np.float64)
        timestamps = data.get('timestamps') or list(range(len(embeddings)))
        if len(timestamps) != len(embeddings):
            return jsonify({'error': 'timestamps must have one entry per embedding'}), 400

        if len(embeddings) <= num_clusters:
            # Too few frames to cluster: every frame is its own representative.
            indices = list(range(len(embeddings)))
            return jsonify({
                'representatives': sorted(indices, key=lambda idx: timestamps[idx]),
                'clusters': indices,
                'num_clusters': len(embeddings)
            })

        # k-means++ initialization (sklearn default) spreads the initial centers.
        from sklearn.cluster import KMeans
        kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init='auto')
        clusters = kmeans.fit_predict(embeddings)

        # Pick the frame closest to each cluster center.
        representatives = []
        for cluster_id, center in enumerate(kmeans.cluster_centers_):
            member_indices = np.flatnonzero(clusters == cluster_id)
            if member_indices.size == 0:
                # A cluster can end up with no members (e.g. duplicate
                # embeddings); argmin on an empty array would raise.
                continue
            distances = np.linalg.norm(embeddings[member_indices] - center, axis=1)
            representatives.append(int(member_indices[np.argmin(distances)]))

        # Sort by timestamp to maintain chronological order
        representatives.sort(key=lambda idx: timestamps[idx])

        return jsonify({
            'representatives': representatives,
            'clusters': clusters.tolist(),
            'num_clusters': num_clusters
        })
    except Exception as e:
        return _error_with_traceback(e)


@app.route('/info', methods=['GET'])
def info():
    """Get service configuration."""
    model_device = next(model.parameters()).device.type
    return jsonify({
        'service': 'clip-embeddings',
        'model': MODEL_NAME,
        'device': DEVICE,
        'model_device': model_device,
        'embedding_dim': EMBEDDING_DIM,
        'gpu_batch_size': GPU_BATCH_SIZE,
        'features': [
            'Single image embedding',
            'GPU-batched embedding extraction',
            'Similarity calculation',
            'Diversity scoring',
            'K-means clustering for diverse selection'
        ]
    })


if __name__ == '__main__':
    print("CLIP Embeddings Service starting...")
    print(f"Model: {MODEL_NAME}")
    print(f"Device: {DEVICE}")
    print(f"GPU Batch Size: {GPU_BATCH_SIZE}")
    app.run(host='0.0.0.0', port=5006, debug=False)