"""
JoyTag Visual Tagging Service

Uses JoyTag Vision Transformer (ViT-B/16) for multi-label image tagging.
Outputs 5,000+ Danbooru-style tags with confidence scores.

Model: fancyfeast/joytag (HuggingFace)
Architecture: ViT-B/16 with 91.5M parameters
Input: 448x448 RGB images (auto-padded to square)

Uses ONNX Runtime for efficient GPU inference.
"""

import asyncio
import base64
import io
import logging
import os
import time
from contextlib import asynccontextmanager
from pathlib import Path

import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model state, populated by load_model() during application startup.
session = None
tag_list = None
input_name = None

# Inference is pushed to a worker thread so the event loop stays responsive
# (e.g. for /health) while the model runs. This lock keeps model runs
# one-at-a-time, as they were when inference executed inline on the loop.
_inference_lock = asyncio.Lock()

# CLIP normalization constants (per RGB channel).
_CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
_CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def prepare_image(image: Image.Image, target_size: int = 448) -> np.ndarray:
    """Prepare image for model inference.

    1. Convert to RGB if the image is in another mode
    2. Pad to square with white borders
    3. Resize to target_size x target_size
    4. Normalize with CLIP normalization values
    5. Convert to NCHW format

    Args:
        image: Source PIL image.
        target_size: Edge length, in pixels, of the square model input.

    Returns:
        float32 array of shape (1, 3, target_size, target_size).
    """
    # The normalization below broadcasts over exactly three channels.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Pad to square, centering the original on a white canvas
    w, h = image.size
    if w != h:
        size = max(w, h)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(image, ((size - w) // 2, (size - h) // 2))
        image = canvas

    # Resize
    image = image.resize((target_size, target_size), Image.LANCZOS)

    # Scale to [0, 1], then apply CLIP normalization
    img_array = np.asarray(image, dtype=np.float32) / 255.0
    img_array = (img_array - _CLIP_MEAN) / _CLIP_STD

    # HWC to CHW, add batch dimension
    img_array = np.transpose(img_array, (2, 0, 1))
    return np.expand_dims(img_array, axis=0)


def load_model():
    """Load JoyTag ONNX model from HuggingFace.

    Downloads ``model.onnx`` and ``top_tags.txt`` into ``MODEL_DIR``
    (default ``/models/joytag``) when they are not already present, creates
    the ONNX Runtime session and publishes ``session``, ``tag_list`` and
    ``input_name`` as module globals.

    Environment variables:
        MODEL_DIR: Directory holding the model files.
        CUDA_DEVICE: CUDA device id used when the CUDA provider is available
            (default 0).
        ORT_INTRA_OP_THREADS: ONNX Runtime intra-op thread count (default 4).

    Raises:
        Exception: Any download, parsing or session-creation failure is
            logged and re-raised.
    """
    global session, tag_list, input_name

    logger.info("Loading JoyTag ONNX model...")

    # Model directory
    model_dir = Path(os.environ.get("MODEL_DIR", "/models/joytag"))
    model_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Download model files if not present
        model_path = model_dir / "model.onnx"
        tags_path = model_dir / "top_tags.txt"

        if not model_path.exists():
            logger.info("Downloading ONNX model from HuggingFace...")
            hf_hub_download(
                repo_id="fancyfeast/joytag",
                filename="model.onnx",
                local_dir=str(model_dir)
            )
            logger.info("ONNX model downloaded")

        if not tags_path.exists():
            logger.info("Downloading tag list from HuggingFace...")
            hf_hub_download(
                repo_id="fancyfeast/joytag",
                filename="top_tags.txt",
                local_dir=str(model_dir)
            )

        # Load tag list. The encoding is explicit so that parsing does not
        # depend on the host locale.
        with open(tags_path, "r", encoding="utf-8") as f:
            loaded_tags = [line.strip() for line in f if line.strip()]
        logger.info(f"Loaded {len(loaded_tags)} tags")

        # Configure ONNX Runtime session
        providers = []

        # Try CUDA first
        if "CUDAExecutionProvider" in ort.get_available_providers():
            providers.append(("CUDAExecutionProvider", {
                "device_id": int(os.environ.get("CUDA_DEVICE", 0)),
                "arena_extend_strategy": "kSameAsRequested",
            }))
            logger.info("Using CUDA execution provider")

        # Fallback to CPU
        providers.append("CPUExecutionProvider")

        # Session options for optimization
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = int(
            os.environ.get("ORT_INTRA_OP_THREADS", 4)
        )

        # Create inference session
        new_session = ort.InferenceSession(
            str(model_path),
            sess_options=sess_options,
            providers=providers
        )

        # Get input name
        new_input_name = new_session.get_inputs()[0].name
        logger.info(f"Model input name: {new_input_name}")

        # Publish globals only once everything loaded, so the module never
        # exposes a tag list without a usable session.
        tag_list = loaded_tags
        input_name = new_input_name
        session = new_session

        # Log which provider is being used
        actual_provider = session.get_providers()[0]
        logger.info(f"JoyTag model loaded successfully using {actual_provider}")

    except Exception as e:
        logger.error(f"Failed to load JoyTag model: {e}")
        raise


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle manager for model loading."""
    load_model()
    yield
    logger.info("Shutting down JoyTag service")


app = FastAPI(
    title="JoyTag Visual Tagging Service",
    description="Multi-label image tagging using JoyTag Vision Transformer (ONNX)",
    version="1.0.0",
    lifespan=lifespan
)


# NOTE: the class docstrings below are left verbatim because pydantic exposes
# them as schema descriptions.

class AnalyzeRequest(BaseModel):
    """Request body for single image analysis."""
    image_base64: str
    threshold: float = 0.35
    top_k: int = 50


class BatchImage(BaseModel):
    """Single image in a batch request."""
    index: int
    image_base64: str
    timestamp: float | None = None


class BatchAnalyzeRequest(BaseModel):
    """Request body for batch image analysis."""
    images: list[BatchImage]
    threshold: float = 0.35
    top_k: int = 50


class Tag(BaseModel):
    """Single tag with confidence."""
    tag: str
    confidence: float


class AnalyzeResponse(BaseModel):
    """Response from single image analysis."""
    tags: list[Tag]
    inference_ms: float
    model_version: str = "joytag-vit-b16"


class BatchResult(BaseModel):
    """Result for a single image in batch."""
    index: int
    timestamp: float | None = None
    tags: list[Tag]
    inference_ms: float


class BatchAnalyzeResponse(BaseModel):
    """Response from batch analysis."""
    results: list[BatchResult]
    batch_inference_ms: float
    images_processed: int


class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    model_loaded: bool
    gpu_available: bool
    provider: str | None = None
    tag_count: int


def decode_image(image_base64: str) -> Image.Image:
    """Decode base64 image to PIL Image.

    Args:
        image_base64: Base64-encoded image bytes.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        HTTPException: 400 when the payload cannot be decoded as an image.
    """
    try:
        img_data = base64.b64decode(image_base64)
        return Image.open(io.BytesIO(img_data)).convert("RGB")
    except Exception as e:
        # Keep the original error as the cause for server-side debugging.
        raise HTTPException(400, f"Invalid image data: {str(e)}") from e


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Apply sigmoid activation.

    Uses exp(-|x|), which is always in (0, 1], so large-magnitude logits
    cannot overflow the way a direct ``1 / (1 + exp(-x))`` does.
    """
    z = np.exp(-np.abs(x))
    return np.where(x >= 0, 1 / (1 + z), z / (1 + z))


def _run_model(batch: np.ndarray) -> np.ndarray:
    """Run the ONNX session on an NCHW batch and return per-image probabilities."""
    outputs = session.run(None, {input_name: batch})
    return sigmoid(outputs[0])


def _select_tags(probs: np.ndarray, threshold: float, top_k: int) -> list[Tag]:
    """Return tags with probability >= threshold, best first, at most top_k."""
    # A negative top_k would otherwise slice from the end ([:-1] drops the
    # last element) instead of limiting the count; treat it as zero.
    limit = max(top_k, 0)

    # Get tags above threshold
    indices = np.where(probs >= threshold)[0]

    # Sort by confidence and limit to top_k
    sorted_indices = indices[np.argsort(-probs[indices])][:limit]

    return [
        Tag(tag=tag_list[i], confidence=float(probs[i]))
        for i in sorted_indices
    ]


def _decode_images(encoded_images: list[str]) -> list[Image.Image]:
    """Decode a list of base64 payloads; raises HTTPException on the first bad one."""
    return [decode_image(encoded) for encoded in encoded_images]


def predict_tags(image: Image.Image, threshold: float, top_k: int) -> tuple[list[Tag], float]:
    """Run inference on a single image.

    Args:
        image: Image to tag.
        threshold: Minimum confidence for a tag to be returned.
        top_k: Maximum number of tags to return (negative values yield none).

    Returns:
        ``(tags, inference_ms)`` where ``inference_ms`` covers preprocessing,
        inference and tag selection.
    """
    start_time = time.perf_counter()

    probs = _run_model(prepare_image(image))[0]  # Remove batch dimension
    tags = _select_tags(probs, threshold, top_k)

    inference_ms = (time.perf_counter() - start_time) * 1000
    return tags, inference_ms


def predict_batch(images: list[Image.Image], threshold: float, top_k: int) -> tuple[list[list[Tag]], float]:
    """Run inference on a batch of images.

    Args:
        images: Images to tag; all are sent to the model in one call.
        threshold: Minimum confidence for a tag to be returned.
        top_k: Maximum number of tags per image (negative values yield none).

    Returns:
        ``(all_tags, batch_ms)`` with one tag list per input image, in input
        order. An empty input returns ``([], 0.0)`` without running the model.
    """
    # np.concatenate raises on an empty sequence, so short-circuit.
    if not images:
        return [], 0.0

    start_time = time.perf_counter()

    # Preprocess all images and stack along the batch axis
    batch = np.concatenate([prepare_image(img) for img in images], axis=0)

    # Batch inference, then per-image tag selection
    all_probs = _run_model(batch)
    all_tags = [_select_tags(probs, threshold, top_k) for probs in all_probs]

    batch_time = (time.perf_counter() - start_time) * 1000
    return all_tags, batch_time


@app.post("/analyze", response_model=AnalyzeResponse)
async def analyze_image(request: AnalyzeRequest):
    """
    Analyze a single image and return predicted tags.

    - **image_base64**: Base64 encoded image (JPEG or PNG)
    - **threshold**: Minimum confidence threshold (default: 0.35)
    - **top_k**: Maximum number of tags to return (default: 50)
    """
    if session is None:
        raise HTTPException(503, "Model not initialized")

    try:
        # Decoding and inference are blocking; run them off the event loop.
        image = await asyncio.to_thread(decode_image, request.image_base64)
        async with _inference_lock:
            tags, inference_ms = await asyncio.to_thread(
                predict_tags, image, request.threshold, request.top_k
            )

        logger.info(f"Analyzed image: {len(tags)} tags in {inference_ms:.1f}ms")

        return AnalyzeResponse(
            tags=tags,
            inference_ms=inference_ms
        )

    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Analysis failed: {e}")
        raise HTTPException(500, f"Analysis failed: {str(e)}")


@app.post("/analyze_batch", response_model=BatchAnalyzeResponse)
async def analyze_batch(request: BatchAnalyzeRequest):
    """
    Analyze multiple images in a batch for better GPU efficiency.

    - **images**: List of images with index and optional timestamp
    - **threshold**: Minimum confidence threshold (default: 0.35)
    - **top_k**: Maximum number of tags per image (default: 50)
    """
    if session is None:
        raise HTTPException(503, "Model not initialized")

    if not request.images:
        raise HTTPException(400, "No images provided")

    try:
        # Decode all images (blocking work, so done in a worker thread)
        images = await asyncio.to_thread(
            _decode_images, [img_req.image_base64 for img_req in request.images]
        )

        # Batch inference
        async with _inference_lock:
            all_tags, batch_time = await asyncio.to_thread(
                predict_batch, images, request.threshold, request.top_k
            )

        # Build results; the batch time is split evenly across images
        per_image_time = batch_time / len(images)
        results = [
            BatchResult(
                index=img_req.index,
                timestamp=img_req.timestamp,
                tags=image_tags,
                inference_ms=per_image_time
            )
            for img_req, image_tags in zip(request.images, all_tags)
        ]

        logger.info(f"Batch analyzed: {len(images)} images in {batch_time:.1f}ms")

        return BatchAnalyzeResponse(
            results=results,
            batch_inference_ms=batch_time,
            images_processed=len(images)
        )

    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Batch analysis failed: {e}")
        raise HTTPException(500, f"Batch analysis failed: {str(e)}")


@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint."""
    model_loaded = session is not None
    provider = None
    gpu_available = False

    if model_loaded:
        provider = session.get_providers()[0]
        gpu_available = "CUDA" in provider

    return HealthResponse(
        status="ok" if model_loaded else "loading",
        model_loaded=model_loaded,
        gpu_available=gpu_available,
        provider=provider,
        tag_count=len(tag_list) if tag_list else 0
    )


@app.get("/tags")
async def get_tags():
    """Get the full list of available tags."""
    if tag_list is None:
        raise HTTPException(503, "Tags not loaded")

    return {
        "count": len(tag_list),
        "tags": tag_list
    }


@app.get("/")
async def root():
    """Root endpoint with service info."""
    return {
        "service": "JoyTag Visual Tagging",
        "version": "1.0.0",
        "model": "joytag-vit-b16 (ONNX)",
        "tags": len(tag_list) if tag_list else 0,
        "endpoints": {
            "/analyze": "POST - Analyze single image",
            "/analyze_batch": "POST - Analyze multiple images",
            "/health": "GET - Health check",
            "/tags": "GET - List all available tags"
        }
    }


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 5003))
    uvicorn.run(app, host="0.0.0.0", port=port)