
"""
JoyTag Visual Tagging Service
Uses JoyTag Vision Transformer (ViT-B/16) for multi-label image tagging.
Outputs 5,000+ Danbooru-style tags with confidence scores.
Model: fancyfeast/joytag (HuggingFace)
Architecture: ViT-B/16 with 91.5M parameters
Input: 448x448 RGB images (auto-padded to square)
Uses ONNX Runtime for efficient GPU inference.
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
import onnxruntime as ort
from PIL import Image
import numpy as np
import base64
import io
import logging
import os
import time
from pathlib import Path
from huggingface_hub import hf_hub_download
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global model state — populated once by load_model() during startup (lifespan).
session = None  # onnxruntime.InferenceSession, None until the model is loaded
tag_list = None  # list[str] of tag names, index-aligned with the model's output logits
input_name = None  # name of the ONNX graph's input tensor
def prepare_image(image: Image.Image, target_size: int = 448) -> np.ndarray:
    """Convert a PIL image into a normalized float32 NCHW tensor.

    Pipeline: pad to a white square, resize to ``target_size`` with Lanczos,
    scale pixels to [0, 1], apply CLIP channel normalization, and prepend a
    batch axis so the result is shaped (1, 3, target_size, target_size).
    """
    width, height = image.size
    if width != height:
        # Center the image on a white square canvas so aspect ratio is kept.
        side = max(width, height)
        canvas = Image.new("RGB", (side, side), (255, 255, 255))
        canvas.paste(image, ((side - width) // 2, (side - height) // 2))
        image = canvas
    image = image.resize((target_size, target_size), Image.LANCZOS)
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    # CLIP per-channel normalization constants (mean, std).
    clip_mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
    clip_std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
    pixels = (pixels - clip_mean) / clip_std
    # HWC -> CHW, then add the leading batch dimension.
    return np.transpose(pixels, (2, 0, 1))[np.newaxis, ...]
def load_model():
    """Download (if missing) and initialize the JoyTag ONNX model and tag list.

    Populates the module globals ``session``, ``tag_list`` and ``input_name``.
    Raises on any failure so startup aborts instead of serving a broken model.
    """
    global session, tag_list, input_name
    logger.info("Loading JoyTag ONNX model...")
    # Local cache directory for the model artifacts.
    model_dir = Path(os.environ.get("MODEL_DIR", "/models/joytag"))
    model_dir.mkdir(parents=True, exist_ok=True)
    try:
        model_path = model_dir / "model.onnx"
        tags_path = model_dir / "top_tags.txt"
        # Fetch any missing artifacts from the HuggingFace hub.
        if not model_path.exists():
            logger.info("Downloading ONNX model from HuggingFace...")
            hf_hub_download(
                repo_id="fancyfeast/joytag",
                filename="model.onnx",
                local_dir=str(model_dir),
            )
            logger.info("ONNX model downloaded")
        if not tags_path.exists():
            logger.info("Downloading tag list from HuggingFace...")
            hf_hub_download(
                repo_id="fancyfeast/joytag",
                filename="top_tags.txt",
                local_dir=str(model_dir),
            )
        # One tag per line; skip blanks. Index order must match model outputs.
        with open(tags_path, "r") as f:
            tag_list = [stripped for stripped in (line.strip() for line in f) if stripped]
        logger.info(f"Loaded {len(tag_list)} tags")
        # Prefer CUDA when available; CPU is always appended as a fallback.
        providers = []
        if "CUDAExecutionProvider" in ort.get_available_providers():
            cuda_options = {
                "device_id": int(os.environ.get("CUDA_DEVICE", 0)),
                "arena_extend_strategy": "kSameAsRequested",
            }
            providers.append(("CUDAExecutionProvider", cuda_options))
            logger.info("Using CUDA execution provider")
        providers.append("CPUExecutionProvider")
        # Session options: full graph optimization, bounded CPU threads.
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4
        session = ort.InferenceSession(
            str(model_path),
            sess_options=sess_options,
            providers=providers,
        )
        input_name = session.get_inputs()[0].name
        logger.info(f"Model input name: {input_name}")
        # Report which provider ORT actually selected.
        actual_provider = session.get_providers()[0]
        logger.info(f"JoyTag model loaded successfully using {actual_provider}")
    except Exception as e:
        logger.error(f"Failed to load JoyTag model: {e}")
        raise
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle manager for model loading."""
    # Load the model synchronously before the app starts accepting requests;
    # everything after `yield` runs at shutdown.
    load_model()
    yield
    logger.info("Shutting down JoyTag service")
# FastAPI application instance; `lifespan` handles model load/teardown.
app = FastAPI(
    title="JoyTag Visual Tagging Service",
    description="Multi-label image tagging using JoyTag Vision Transformer (ONNX)",
    version="1.0.0",
    lifespan=lifespan
)
class AnalyzeRequest(BaseModel):
    """Request body for single image analysis."""
    # Base64-encoded image bytes (JPEG or PNG).
    image_base64: str
    # Minimum confidence for a tag to be included in the response.
    threshold: float = 0.35
    # Maximum number of tags returned, highest confidence first.
    top_k: int = 50
class BatchImage(BaseModel):
    """Single image in a batch request."""
    # Caller-supplied index, echoed back so results can be correlated.
    index: int
    # Base64-encoded image bytes (JPEG or PNG).
    image_base64: str
    # Optional media timestamp (e.g. video frame time), echoed back unchanged.
    timestamp: float | None = None
class BatchAnalyzeRequest(BaseModel):
    """Request body for batch image analysis."""
    # Images to analyze in a single batched forward pass.
    images: list[BatchImage]
    # Minimum confidence for a tag to be included (applies to every image).
    threshold: float = 0.35
    # Maximum number of tags per image.
    top_k: int = 50
class Tag(BaseModel):
    """Single tag with confidence."""
    # Danbooru-style tag name from the model's tag vocabulary.
    tag: str
    # Sigmoid probability in [0, 1] for this tag.
    confidence: float
class AnalyzeResponse(BaseModel):
    """Response from single image analysis."""
    # Pydantic v2 reserves the "model_" field prefix; without this override
    # the `model_version` field triggers a protected-namespace UserWarning.
    # (Harmlessly ignored by Pydantic v1.)
    model_config = {"protected_namespaces": ()}
    # Predicted tags, sorted by descending confidence.
    tags: list[Tag]
    # Wall-clock inference time in milliseconds (preprocess + forward pass).
    inference_ms: float
    # Identifier of the model that produced the tags.
    model_version: str = "joytag-vit-b16"
class BatchResult(BaseModel):
    """Result for a single image in batch."""
    # Echo of the caller-supplied index from the request.
    index: int
    # Echo of the caller-supplied timestamp, if any.
    timestamp: float | None = None
    # Predicted tags for this image, sorted by descending confidence.
    tags: list[Tag]
    # Per-image share of the batch inference time, in milliseconds.
    inference_ms: float
class BatchAnalyzeResponse(BaseModel):
    """Response from batch analysis."""
    # One result per input image, in request order.
    results: list[BatchResult]
    # Total wall-clock time for the whole batch, in milliseconds.
    batch_inference_ms: float
    # Number of images actually processed.
    images_processed: int
class HealthResponse(BaseModel):
    """Health check response."""
    # Pydantic v2 reserves the "model_" field prefix; without this override
    # the `model_loaded` field triggers a protected-namespace UserWarning.
    # (Harmlessly ignored by Pydantic v1.)
    model_config = {"protected_namespaces": ()}
    # "ok" once the model is loaded, "loading" before that.
    status: str
    # Whether the ONNX session has been initialized.
    model_loaded: bool
    # Whether the active execution provider is CUDA.
    gpu_available: bool
    # Name of the active ONNX Runtime execution provider, if loaded.
    provider: str | None = None
    # Number of tags in the loaded vocabulary (0 before load).
    tag_count: int
def decode_image(image_base64: str) -> Image.Image:
    """Decode a base64 payload into an RGB PIL image.

    Raises HTTPException(400) when the payload is not valid base64 or not a
    decodable image.
    """
    try:
        raw = base64.b64decode(image_base64)
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as e:
        raise HTTPException(400, f"Invalid image data: {str(e)}")
def sigmoid(x: np.ndarray) -> np.ndarray:
    """Numerically stable element-wise sigmoid.

    Computes sigmoid(x) = exp(-log(1 + exp(-x))) via np.logaddexp so that
    large-magnitude negative logits do not overflow np.exp — the naive
    1 / (1 + np.exp(-x)) form emits a RuntimeWarning and produces an inf
    intermediate for x << 0 (the final value was still correct; this is a
    cleanliness/robustness fix, results are unchanged).
    """
    return np.exp(-np.logaddexp(0.0, -x))
def predict_tags(image: Image.Image, threshold: float, top_k: int) -> tuple[list[Tag], float]:
    """Run inference on one image; return (tags, elapsed milliseconds)."""
    t0 = time.perf_counter()
    tensor = prepare_image(image)
    # Forward pass; output 0 holds the per-tag logits, strip the batch axis.
    logits = session.run(None, {input_name: tensor})[0][0]
    probs = sigmoid(logits)
    # Indices above threshold, ranked by descending confidence, capped at top_k.
    candidates = np.where(probs >= threshold)[0]
    ranked = candidates[np.argsort(-probs[candidates])][:top_k]
    results = [
        Tag(tag=tag_list[idx], confidence=float(probs[idx]))
        for idx in ranked
    ]
    elapsed_ms = (time.perf_counter() - t0) * 1000
    return results, elapsed_ms
def predict_batch(images: list[Image.Image], threshold: float, top_k: int) -> tuple[list[list[Tag]], float]:
    """Run one batched forward pass; return (per-image tag lists, total ms)."""
    t0 = time.perf_counter()
    # Stack the individual (1, C, H, W) tensors into one (N, C, H, W) batch.
    batch = np.concatenate([prepare_image(img) for img in images], axis=0)
    logits = session.run(None, {input_name: batch})[0]
    probabilities = sigmoid(logits)
    all_tags = []
    for probs in probabilities:
        # Same thresholding/ranking as the single-image path.
        keep = np.where(probs >= threshold)[0]
        ordered = keep[np.argsort(-probs[keep])][:top_k]
        all_tags.append(
            [Tag(tag=tag_list[i], confidence=float(probs[i])) for i in ordered]
        )
    elapsed_ms = (time.perf_counter() - t0) * 1000
    return all_tags, elapsed_ms
@app.post("/analyze", response_model=AnalyzeResponse)
async def analyze_image(request: AnalyzeRequest):
    """
    Analyze a single image and return predicted tags.
    - **image_base64**: Base64 encoded image (JPEG or PNG)
    - **threshold**: Minimum confidence threshold (default: 0.35)
    - **top_k**: Maximum number of tags to return (default: 50)
    """
    # Guard: reject requests that arrive before startup finished.
    if session is None:
        raise HTTPException(503, "Model not initialized")
    try:
        decoded = decode_image(request.image_base64)
        tags, elapsed = predict_tags(decoded, request.threshold, request.top_k)
        logger.info(f"Analyzed image: {len(tags)} tags in {elapsed:.1f}ms")
        return AnalyzeResponse(
            tags=tags,
            inference_ms=elapsed
        )
    except HTTPException:
        # Pass through deliberate HTTP errors (e.g. 400 from decode_image).
        raise
    except Exception as e:
        logger.error(f"Analysis failed: {e}")
        raise HTTPException(500, f"Analysis failed: {str(e)}")
@app.post("/analyze_batch", response_model=BatchAnalyzeResponse)
async def analyze_batch(request: BatchAnalyzeRequest):
    """
    Analyze multiple images in a batch for better GPU efficiency.
    - **images**: List of images with index and optional timestamp
    - **threshold**: Minimum confidence threshold (default: 0.35)
    - **top_k**: Maximum number of tags per image (default: 50)
    """
    # Guards: model must be ready and the batch must be non-empty.
    if session is None:
        raise HTTPException(503, "Model not initialized")
    if not request.images:
        raise HTTPException(400, "No images provided")
    try:
        # Decode every payload up front; one bad image fails the batch (400).
        decoded = [decode_image(item.image_base64) for item in request.images]
        all_tags, total_ms = predict_batch(decoded, request.threshold, request.top_k)
        # Attribute an equal share of the batch time to each image.
        share_ms = total_ms / len(decoded)
        results = [
            BatchResult(
                index=item.index,
                timestamp=item.timestamp,
                tags=all_tags[pos],
                inference_ms=share_ms
            )
            for pos, item in enumerate(request.images)
        ]
        logger.info(f"Batch analyzed: {len(decoded)} images in {total_ms:.1f}ms")
        return BatchAnalyzeResponse(
            results=results,
            batch_inference_ms=total_ms,
            images_processed=len(decoded)
        )
    except HTTPException:
        # Pass through deliberate HTTP errors unchanged.
        raise
    except Exception as e:
        logger.error(f"Batch analysis failed: {e}")
        raise HTTPException(500, f"Batch analysis failed: {str(e)}")
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint."""
    active_provider = None
    has_gpu = False
    if session is not None:
        # First entry is the provider ONNX Runtime actually selected.
        active_provider = session.get_providers()[0]
        has_gpu = "CUDA" in active_provider
    return HealthResponse(
        status="ok" if session is not None else "loading",
        model_loaded=session is not None,
        gpu_available=has_gpu,
        provider=active_provider,
        tag_count=len(tag_list) if tag_list else 0
    )
@app.get("/tags")
async def get_tags():
    """Get the full list of available tags."""
    # 503 until load_model() has populated the vocabulary.
    if tag_list is None:
        raise HTTPException(503, "Tags not loaded")
    return {"count": len(tag_list), "tags": tag_list}
@app.get("/")
async def root():
    """Root endpoint with service info."""
    # Static route catalogue for discoverability.
    endpoints = {
        "/analyze": "POST - Analyze single image",
        "/analyze_batch": "POST - Analyze multiple images",
        "/health": "GET - Health check",
        "/tags": "GET - List all available tags"
    }
    return {
        "service": "JoyTag Visual Tagging",
        "version": "1.0.0",
        "model": "joytag-vit-b16 (ONNX)",
        "tags": len(tag_list) if tag_list else 0,
        "endpoints": endpoints
    }
if __name__ == "__main__":
    import uvicorn
    # PORT env var overrides the default service port (5003).
    port = int(os.environ.get("PORT", 5003))
    uvicorn.run(app, host="0.0.0.0", port=port)