""" AD-Copilot Demo: Comparison-Aware Anomaly Detection with Vision-Language Model """ import json import os import re import time import traceback import spaces import gradio as gr import torch from transformers import AutoModelForImageTextToText, AutoProcessor from qwen_vl_utils import process_vision_info from PIL import Image, ImageDraw, ImageFont # --------------------------------------------------------------------------- # Model loading (happens once at Space startup; weights stay on CPU until # @spaces.GPU moves them to GPU on demand) # --------------------------------------------------------------------------- MODEL_ID = "jiang-cc/AD-Copilot" processor = AutoProcessor.from_pretrained( MODEL_ID, min_pixels=64 * 28 * 28, max_pixels=1280 * 28 * 28, trust_remote_code=True, ) try: import flash_attn # noqa: F401 _attn_impl = "flash_attention_2" except ImportError: _attn_impl = "sdpa" model = AutoModelForImageTextToText.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, attn_implementation=_attn_impl, trust_remote_code=True, ).to("cuda").eval() print(f"[AD-Copilot] Attention: {_attn_impl} | Device: {model.device}", flush=True) # --------------------------------------------------------------------------- # BBox visualization # --------------------------------------------------------------------------- COLORS = [ "#FF4444", "#44AA44", "#4488FF", "#FF8800", "#AA44FF", "#00CCCC", "#FF44AA", "#88AA00", ] def parse_bboxes(text): """Try to extract bbox JSON from model output.""" pattern = r'```(?:json)?\s*(\[.*?\])\s*```' match = re.search(pattern, text, re.DOTALL) if match: raw = match.group(1) else: match = re.search(r'(\[\s*\{.*?\}\s*\])', text, re.DOTALL) if match: raw = match.group(1) else: return None try: bboxes = json.loads(raw) if isinstance(bboxes, list) and len(bboxes) > 0 and "bbox_2d" in bboxes[0]: # Normalize: accept both "label" and "bbox_label" for b in bboxes: if "label" not in b and "bbox_label" in b: b["label"] = b.pop("bbox_label") return bboxes except json.JSONDecodeError: pass return None def draw_bboxes(image, bboxes): """Draw bounding boxes with labels on image.""" img = image.copy() draw = ImageDraw.Draw(img) try: font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16) small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 13) except (IOError, OSError): font = ImageFont.load_default() small_font = font for i, bbox_info in enumerate(bboxes): bbox = bbox_info.get("bbox_2d", []) label = bbox_info.get("label", f"defect_{i}") if len(bbox) != 4: continue x1, y1, x2, y2 = bbox color = COLORS[i % len(COLORS)] for w in range(3): draw.rectangle([x1 - w, y1 - w, x2 + w, y2 + w], outline=color) text_bbox = draw.textbbox((0, 0), label, font=small_font) tw = text_bbox[2] - text_bbox[0] + 8 th = text_bbox[3] - text_bbox[1] + 6 label_y = max(0, y1 - th - 2) draw.rectangle([x1, label_y, x1 + tw, label_y + th], fill=color) draw.text((x1 + 4, label_y + 2), label, fill="white", font=small_font) return img # --------------------------------------------------------------------------- # Inference (supports both single-image and paired-image modes) # --------------------------------------------------------------------------- def _run_inference(reference_image, test_image, prompt, max_new_tokens, _t_enter=None): """Core inference logic shared by both predict functions.""" t_gpu_ready = time.time() # GPU is allocated by this point has_ref = reference_image is not None has_test = test_image is not None if not has_ref and not has_test: return "Please upload at least one image.", None try: t0 = time.time() max_new_tokens = int(max_new_tokens) content = [] if has_ref and has_test: ref = reference_image.copy() tst = test_image.copy() ref.thumbnail((512, 512), Image.Resampling.LANCZOS) tst.thumbnail((512, 512), Image.Resampling.LANCZOS) content.append({"type": "image", "image": ref}) content.append({"type": "image", "image": tst}) vis_source = tst elif has_test: tst = test_image.copy() tst.thumbnail((512, 512), Image.Resampling.LANCZOS) content.append({"type": "image", "image": tst}) vis_source = tst else: ref = reference_image.copy() ref.thumbnail((512, 512), Image.Resampling.LANCZOS) content.append({"type": "image", "image": ref}) vis_source = ref content.append({"type": "text", "text": prompt}) messages = [{"role": "user", "content": content}] text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) image_inputs, video_inputs = process_vision_info(messages) inputs = processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ).to(model.device) t_preprocess = time.time() generated_ids = model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=False ) t_generate = time.time() generated_ids_trimmed = [ out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids) ] n_tokens = generated_ids.shape[1] - inputs.input_ids.shape[1] output = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False, )[0] bboxes = parse_bboxes(output) vis_image = None if bboxes: vis_image = draw_bboxes(vis_source, bboxes) gpu_load_t = t_gpu_ready - _t_enter if _t_enter else 0 prep_t = t_preprocess - t0 gen_t = t_generate - t_preprocess tps = n_tokens / gen_t if gen_t > 0 else 0 parts = [] if gpu_load_t > 0.5: parts.append(f"GPU Load: {gpu_load_t:.1f}s") parts.append(f"Preprocess: {prep_t:.1f}s") parts.append(f"Generate: {gen_t:.1f}s ({n_tokens} tokens, {tps:.1f} tok/s)") output += f"\n\n---\n[{_attn_impl}] {' | '.join(parts)}" return output, vis_image except Exception as e: tb = traceback.format_exc() print(tb, flush=True) return f"Error:\n{tb}", None def _wrap_predict(fn): def wrapper(reference_image, test_image, prompt, max_new_tokens): t_enter = time.time() return fn(reference_image, test_image, prompt, max_new_tokens, _t_enter=t_enter) return wrapper @spaces.GPU(duration=5) def _predict_gpu(reference_image, test_image, prompt, max_new_tokens, _t_enter=None): return _run_inference(reference_image, test_image, prompt, max_new_tokens, _t_enter=_t_enter) @spaces.GPU(duration=10) def _predict_long_gpu(reference_image, test_image, prompt, max_new_tokens, _t_enter=None): return _run_inference(reference_image, test_image, prompt, max_new_tokens, _t_enter=_t_enter) predict = _wrap_predict(_predict_gpu) predict_long = _wrap_predict(_predict_long_gpu) # --------------------------------------------------------------------------- # Gradio UI # --------------------------------------------------------------------------- TITLE = "AD-Copilot: Comparison-Aware Anomaly Detection" DESCRIPTION = """ **AD-Copilot** extends Qwen2.5-VL with a novel **comparison-aware visual encoder** that generates special comparison tokens capturing differences between a reference image and a test image, achieving state-of-the-art results on industrial anomaly detection benchmarks. **Two modes:** Upload both images for comparison-based inspection, or just a test image for single-image tasks (counting, OCR, etc.). [[Paper]](https://arxiv.org/abs/2603.13779v1) | [[Code]](https://github.com/jam-cc/AD-Copilot) | [[Model]](https://huggingface.co/jiang-cc/AD-Copilot) """ EXAMPLES = [ # 1. Anomaly Discrimination (yes/no) — bottle contamination [ "examples/bottle_good.jpg", "examples/bottle_contamination.jpg", "The first image is a normal sample. Is there any anomaly in the second image? A. Yes B. No. Please answer the letter only.", 128, ], # 2. Defect Description — cable cut [ "examples/cable_good.jpg", "examples/cable_cut.jpg", "The first image is a normal sample. Compared with the first image, please describe the anomaly in the second image in detail.", 256, ], # 3. Fine-grained Defect Localization — bottle contamination [ "examples/bottle_good.jpg", "examples/bottle_contamination.jpg", "The first image is a normal sample. Please locate the defects within the second image with bounding box in JSON format.", 256, ], # 4. Defect Localization — PCB wrong position [ "examples/pcb_good.jpg", "examples/pcb_defect.jpg", "The first image is a normal sample. Please locate the defects within the second image with bounding box in JSON format.", 256, ], # 5. Object Counting (single image) — candle [ None, "examples/candle_count.jpg", "How many candles are in this image? For each candle, give its bounding box as {\"bbox_2d\": [x1,y1,x2,y2], \"label\": \"candle_1\"}, numbering them 1, 2, 3, etc. State the total count at the end.", 512, ], # 7. Industrial OCR (single image) — drink bottle [ None, "examples/drink_bottle_ocr.jpg", "Please read all text and labels on this product.", 256, ], ] with gr.Blocks(theme=gr.themes.Soft(), title=TITLE) as demo: gr.Markdown(f"# {TITLE}") gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(scale=1): ref_img = gr.Image( label="Reference (Good) Image (optional)", type="pil", height=300, ) with gr.Column(scale=1): test_img = gr.Image( label="Test Image", type="pil", height=300, ) prompt = gr.Textbox( label="Prompt", value="The first image is a normal sample. Is there any anomaly in the second image? A. Yes B. No. Please answer the letter only.", lines=2, ) with gr.Row(): max_tokens = gr.Slider( minimum=16, maximum=1024, value=128, step=16, label="Max New Tokens", ) run_btn = gr.Button("Run (5s)", variant="primary", scale=2) run_long_btn = gr.Button("Run Long (10s)", variant="secondary", scale=1) output = gr.Textbox(label="Model Output", lines=4) vis_output = gr.Image(label="Detection Visualization") run_btn.click( fn=predict, inputs=[ref_img, test_img, prompt, max_tokens], outputs=[output, vis_output], ) run_long_btn.click( fn=predict_long, inputs=[ref_img, test_img, prompt, max_tokens], outputs=[output, vis_output], ) gr.Examples( examples=EXAMPLES, inputs=[ref_img, test_img, prompt, max_tokens], outputs=[output, vis_output], fn=predict, cache_examples=False, ) if __name__ == "__main__": demo.launch()