"""
JoyTag Visual Tagging Service

Uses JoyTag Vision Transformer (ViT-B/16) for multi-label image tagging.
Outputs 5,000+ Danbooru-style tags with confidence scores.

Model: fancyfeast/joytag (HuggingFace)
Architecture: ViT-B/16 with 91.5M parameters
Input: 448x448 RGB images (auto-padded to square)

Uses ONNX Runtime for efficient GPU inference.
"""

from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from contextlib import asynccontextmanager
|
|
import onnxruntime as ort
|
|
from PIL import Image
|
|
import numpy as np
|
|
import base64
|
|
import io
|
|
import logging
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model state, populated exactly once by load_model() at startup:
#   session    - onnxruntime.InferenceSession for the JoyTag model
#   tag_list   - list[str] of tag names, index-aligned with the model's output logits
#   input_name - name of the model's single image input tensor
session = None
tag_list = None
input_name = None


def prepare_image(image: Image.Image, target_size: int = 448) -> np.ndarray:
    """Prepare an image for model inference.

    1. Convert to RGB (guards against grayscale/RGBA/palette inputs)
    2. Pad to square with white borders
    3. Resize to target_size x target_size
    4. Normalize with CLIP normalization values
    5. Convert to NCHW format with a leading batch dimension

    Args:
        image: Source image in any PIL mode.
        target_size: Edge length of the square model input (JoyTag expects 448).

    Returns:
        float32 array of shape (1, 3, target_size, target_size).
    """
    # Ensure 3 channels: np.array on an "L" or "RGBA" image would produce
    # a 2-D or 4-channel array and break the (2, 0, 1) transpose below.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Pad to a centered square on white so the aspect ratio survives the resize
    w, h = image.size
    if w != h:
        size = max(w, h)
        new_image = Image.new("RGB", (size, size), (255, 255, 255))
        paste_x = (size - w) // 2
        paste_y = (size - h) // 2
        new_image.paste(image, (paste_x, paste_y))
        image = new_image

    # Resize with a high-quality filter
    image = image.resize((target_size, target_size), Image.LANCZOS)

    # Scale to [0, 1], then apply CLIP per-channel normalization
    img_array = np.array(image).astype(np.float32) / 255.0
    mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
    std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
    img_array = (img_array - mean) / std

    # HWC -> CHW, then add batch dimension -> NCHW
    img_array = np.transpose(img_array, (2, 0, 1))
    img_array = np.expand_dims(img_array, axis=0)

    return img_array


def load_model():
    """Load the JoyTag ONNX model and tag list, downloading them if absent.

    Populates the module globals ``session``, ``tag_list`` and
    ``input_name``. Intended to be called once at startup (see lifespan).

    Environment variables:
        MODEL_DIR: local cache directory for model files (default /models/joytag)
        CUDA_DEVICE: GPU index for the CUDA execution provider (default 0)

    Raises:
        Exception: re-raises any download/load failure after logging it,
            so startup fails loudly instead of serving with no model.
    """
    global session, tag_list, input_name

    logger.info("Loading JoyTag ONNX model...")

    # Model directory
    model_dir = Path(os.environ.get("MODEL_DIR", "/models/joytag"))
    model_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Download model files if not present
        model_path = model_dir / "model.onnx"
        tags_path = model_dir / "top_tags.txt"

        if not model_path.exists():
            logger.info("Downloading ONNX model from HuggingFace...")
            hf_hub_download(
                repo_id="fancyfeast/joytag",
                filename="model.onnx",
                local_dir=str(model_dir)
            )
            logger.info("ONNX model downloaded")

        if not tags_path.exists():
            logger.info("Downloading tag list from HuggingFace...")
            hf_hub_download(
                repo_id="fancyfeast/joytag",
                filename="top_tags.txt",
                local_dir=str(model_dir)
            )

        # Load tag list; one tag per line, index-aligned with model outputs.
        # Blank lines are skipped so trailing newlines don't shift indices.
        with open(tags_path, "r") as f:
            tag_list = [line.strip() for line in f if line.strip()]
        logger.info(f"Loaded {len(tag_list)} tags")

        # Configure ONNX Runtime session
        providers = []

        # Try CUDA first; ORT falls through the provider list in order
        if "CUDAExecutionProvider" in ort.get_available_providers():
            providers.append(("CUDAExecutionProvider", {
                "device_id": int(os.environ.get("CUDA_DEVICE", 0)),
                # Grow the GPU arena only as much as each request needs
                "arena_extend_strategy": "kSameAsRequested",
            }))
            logger.info("Using CUDA execution provider")

        # Fallback to CPU (always appended, so CPU-only hosts still work)
        providers.append("CPUExecutionProvider")

        # Session options for optimization
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4

        # Create inference session
        session = ort.InferenceSession(
            str(model_path),
            sess_options=sess_options,
            providers=providers
        )

        # Get input name (model is assumed to have a single image input)
        input_name = session.get_inputs()[0].name
        logger.info(f"Model input name: {input_name}")

        # Log which provider is being used (first entry is the active one)
        actual_provider = session.get_providers()[0]
        logger.info(f"JoyTag model loaded successfully using {actual_provider}")

    except Exception as e:
        logger.error(f"Failed to load JoyTag model: {e}")
        raise


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle manager for model loading.

    Loads the model synchronously before the app accepts requests;
    the code after ``yield`` runs once at shutdown.
    """
    load_model()
    yield
    logger.info("Shutting down JoyTag service")


# FastAPI application; model loading is wired to the lifespan handler so the
# service only starts serving once the ONNX session and tag list are ready.
app = FastAPI(
    title="JoyTag Visual Tagging Service",
    description="Multi-label image tagging using JoyTag Vision Transformer (ONNX)",
    version="1.0.0",
    lifespan=lifespan
)


class AnalyzeRequest(BaseModel):
    """Request body for single image analysis."""
    image_base64: str  # base64-encoded image bytes (JPEG or PNG)
    threshold: float = 0.35  # minimum confidence for a tag to be returned
    top_k: int = 50  # cap on the number of returned tags


class BatchImage(BaseModel):
    """Single image in a batch request."""
    index: int  # caller-supplied identifier, echoed back in BatchResult
    image_base64: str  # base64-encoded image bytes (JPEG or PNG)
    timestamp: float | None = None  # optional timestamp, echoed back unchanged


class BatchAnalyzeRequest(BaseModel):
    """Request body for batch image analysis."""
    images: list[BatchImage]  # images to analyze in one forward pass
    threshold: float = 0.35  # minimum confidence for a tag to be returned
    top_k: int = 50  # cap on the number of tags per image


class Tag(BaseModel):
    """Single tag with confidence."""
    tag: str  # Danbooru-style tag name from the model's tag list
    confidence: float  # sigmoid probability in [0, 1]


class AnalyzeResponse(BaseModel):
    """Response from single image analysis."""
    tags: list[Tag]  # tags above threshold, sorted by descending confidence
    inference_ms: float  # wall-clock preprocessing + inference time
    model_version: str = "joytag-vit-b16"  # static model identifier


class BatchResult(BaseModel):
    """Result for a single image in batch."""
    index: int  # echoed from the request's BatchImage.index
    timestamp: float | None = None  # echoed from the request, if provided
    tags: list[Tag]  # tags above threshold, sorted by descending confidence
    inference_ms: float  # equal share of the total batch wall-clock time


class BatchAnalyzeResponse(BaseModel):
    """Response from batch analysis."""
    results: list[BatchResult]  # one entry per input image, in request order
    batch_inference_ms: float  # total wall-clock time for the whole batch
    images_processed: int  # number of images in the batch


class HealthResponse(BaseModel):
    """Health check response."""
    status: str  # "ok" once the model is loaded, "loading" before
    model_loaded: bool  # whether the ONNX session exists
    gpu_available: bool  # True when the active ORT provider is CUDA-based
    provider: str | None = None  # active ONNX Runtime execution provider name
    tag_count: int  # size of the loaded tag list (0 before load)


def decode_image(image_base64: str) -> Image.Image:
    """Decode a base64 payload into an RGB PIL Image.

    Accepts raw base64 strings as well as browser-style data URIs
    ("data:image/png;base64,...."); the data-URI prefix is stripped
    before decoding. Raw base64 can never start with "data:" nor
    contain ",", so this is backward-compatible.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or
            cannot be decoded as an image.
    """
    try:
        # Tolerate data-URI prefixes sent by browser clients
        if "," in image_base64 and image_base64.lstrip().startswith("data:"):
            image_base64 = image_base64.split(",", 1)[1]
        img_data = base64.b64decode(image_base64)
        image = Image.open(io.BytesIO(img_data)).convert("RGB")
        return image
    except Exception as e:
        # Chain the cause so server logs show the underlying decode error
        raise HTTPException(400, f"Invalid image data: {str(e)}") from e


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Apply sigmoid activation element-wise, numerically stably.

    The naive ``1 / (1 + np.exp(-x))`` overflows (RuntimeWarning, inf)
    for large negative logits. Here the exponential is only ever taken
    of a non-positive value, so it can never overflow, while the result
    is mathematically identical.

    Args:
        x: Array of logits.

    Returns:
        Array of probabilities in [0, 1], same shape as ``x``.
    """
    pos = x >= 0
    # exp() argument is <= 0 in both branches, so no overflow is possible
    z = np.exp(np.where(pos, -x, x))
    return np.where(pos, 1.0 / (1.0 + z), z / (1.0 + z))


def predict_tags(image: Image.Image, threshold: float, top_k: int) -> tuple[list[Tag], float]:
    """Run inference on a single image.

    Returns the tags whose probability clears ``threshold`` (at most
    ``top_k``, best first) and the elapsed wall-clock time in ms.
    """
    t0 = time.perf_counter()

    # Preprocess to a (1, 3, 448, 448) tensor and run the model
    tensor = prepare_image(image)
    logits = session.run(None, {input_name: tensor})[0][0]  # drop batch dim
    probs = sigmoid(logits)

    # Keep indices above the threshold, best-confidence first, capped at top_k
    candidates = np.where(probs >= threshold)[0]
    selected = candidates[np.argsort(-probs[candidates])][:top_k]

    result = [Tag(tag=tag_list[i], confidence=float(probs[i])) for i in selected]

    elapsed_ms = (time.perf_counter() - t0) * 1000
    return result, elapsed_ms


def predict_batch(images: list[Image.Image], threshold: float, top_k: int) -> tuple[list[list[Tag]], float]:
    """Run inference on a batch of images.

    All images go through the model in one forward pass. Returns one tag
    list per input image (same order) and the total batch time in ms.
    """
    t0 = time.perf_counter()

    # Stack preprocessed images into a single (N, 3, 448, 448) tensor
    stacked = np.concatenate([prepare_image(img) for img in images], axis=0)

    # One forward pass for the whole batch, then per-image probabilities
    all_probs = sigmoid(session.run(None, {input_name: stacked})[0])

    results: list[list[Tag]] = []
    for probs in all_probs:
        candidates = np.where(probs >= threshold)[0]
        chosen = candidates[np.argsort(-probs[candidates])][:top_k]
        results.append(
            [Tag(tag=tag_list[i], confidence=float(probs[i])) for i in chosen]
        )

    elapsed_ms = (time.perf_counter() - t0) * 1000
    return results, elapsed_ms


@app.post("/analyze", response_model=AnalyzeResponse)
|
|
async def analyze_image(request: AnalyzeRequest):
|
|
"""
|
|
Analyze a single image and return predicted tags.
|
|
|
|
- **image_base64**: Base64 encoded image (JPEG or PNG)
|
|
- **threshold**: Minimum confidence threshold (default: 0.35)
|
|
- **top_k**: Maximum number of tags to return (default: 50)
|
|
"""
|
|
if session is None:
|
|
raise HTTPException(503, "Model not initialized")
|
|
|
|
try:
|
|
image = decode_image(request.image_base64)
|
|
tags, inference_ms = predict_tags(image, request.threshold, request.top_k)
|
|
|
|
logger.info(f"Analyzed image: {len(tags)} tags in {inference_ms:.1f}ms")
|
|
|
|
return AnalyzeResponse(
|
|
tags=tags,
|
|
inference_ms=inference_ms
|
|
)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Analysis failed: {e}")
|
|
raise HTTPException(500, f"Analysis failed: {str(e)}")
|
|
|
|
|
|
@app.post("/analyze_batch", response_model=BatchAnalyzeResponse)
|
|
async def analyze_batch(request: BatchAnalyzeRequest):
|
|
"""
|
|
Analyze multiple images in a batch for better GPU efficiency.
|
|
|
|
- **images**: List of images with index and optional timestamp
|
|
- **threshold**: Minimum confidence threshold (default: 0.35)
|
|
- **top_k**: Maximum number of tags per image (default: 50)
|
|
"""
|
|
if session is None:
|
|
raise HTTPException(503, "Model not initialized")
|
|
|
|
if not request.images:
|
|
raise HTTPException(400, "No images provided")
|
|
|
|
try:
|
|
# Decode all images
|
|
images = [decode_image(img_req.image_base64) for img_req in request.images]
|
|
|
|
# Batch inference
|
|
all_tags, batch_time = predict_batch(images, request.threshold, request.top_k)
|
|
|
|
# Build results
|
|
per_image_time = batch_time / len(images)
|
|
results = [
|
|
BatchResult(
|
|
index=img_req.index,
|
|
timestamp=img_req.timestamp,
|
|
tags=all_tags[i],
|
|
inference_ms=per_image_time
|
|
)
|
|
for i, img_req in enumerate(request.images)
|
|
]
|
|
|
|
logger.info(f"Batch analyzed: {len(images)} images in {batch_time:.1f}ms")
|
|
|
|
return BatchAnalyzeResponse(
|
|
results=results,
|
|
batch_inference_ms=batch_time,
|
|
images_processed=len(images)
|
|
)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Batch analysis failed: {e}")
|
|
raise HTTPException(500, f"Batch analysis failed: {str(e)}")
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
|
|
async def health():
|
|
"""Health check endpoint."""
|
|
provider = None
|
|
gpu_available = False
|
|
|
|
if session is not None:
|
|
provider = session.get_providers()[0]
|
|
gpu_available = "CUDA" in provider
|
|
|
|
return HealthResponse(
|
|
status="ok" if session is not None else "loading",
|
|
model_loaded=session is not None,
|
|
gpu_available=gpu_available,
|
|
provider=provider,
|
|
tag_count=len(tag_list) if tag_list else 0
|
|
)
|
|
|
|
|
|
@app.get("/tags")
|
|
async def get_tags():
|
|
"""Get the full list of available tags."""
|
|
if tag_list is None:
|
|
raise HTTPException(503, "Tags not loaded")
|
|
|
|
return {
|
|
"count": len(tag_list),
|
|
"tags": tag_list
|
|
}
|
|
|
|
|
|
@app.get("/")
|
|
async def root():
|
|
"""Root endpoint with service info."""
|
|
return {
|
|
"service": "JoyTag Visual Tagging",
|
|
"version": "1.0.0",
|
|
"model": "joytag-vit-b16 (ONNX)",
|
|
"tags": len(tag_list) if tag_list else 0,
|
|
"endpoints": {
|
|
"/analyze": "POST - Analyze single image",
|
|
"/analyze_batch": "POST - Analyze multiple images",
|
|
"/health": "GET - Health check",
|
|
"/tags": "GET - List all available tags"
|
|
}
|
|
}
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces; port is configurable via the PORT env var
    listen_port = int(os.environ.get("PORT", 5003))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)