# Hugging Face Space page header ("Running on Zero" = ZeroGPU hardware) —
# residue from the page scrape, kept here only as a note.
| """ | |
| AD-Copilot Demo: Comparison-Aware Anomaly Detection with Vision-Language Model | |
| """ | |
| import json | |
| import os | |
| import re | |
| import time | |
| import traceback | |
| import spaces | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForImageTextToText, AutoProcessor | |
| from qwen_vl_utils import process_vision_info | |
| from PIL import Image, ImageDraw, ImageFont | |
| # --------------------------------------------------------------------------- | |
| # Model loading (happens once at Space startup; weights stay on CPU until | |
| # @spaces.GPU moves them to GPU on demand) | |
| # --------------------------------------------------------------------------- | |
| MODEL_ID = "jiang-cc/AD-Copilot" | |
| processor = AutoProcessor.from_pretrained( | |
| MODEL_ID, | |
| min_pixels=64 * 28 * 28, | |
| max_pixels=1280 * 28 * 28, | |
| trust_remote_code=True, | |
| ) | |
| try: | |
| import flash_attn # noqa: F401 | |
| _attn_impl = "flash_attention_2" | |
| except ImportError: | |
| _attn_impl = "sdpa" | |
| model = AutoModelForImageTextToText.from_pretrained( | |
| MODEL_ID, | |
| torch_dtype=torch.bfloat16, | |
| attn_implementation=_attn_impl, | |
| trust_remote_code=True, | |
| ).to("cuda").eval() | |
| print(f"[AD-Copilot] Attention: {_attn_impl} | Device: {model.device}", flush=True) | |
| # --------------------------------------------------------------------------- | |
| # BBox visualization | |
| # --------------------------------------------------------------------------- | |
| COLORS = [ | |
| "#FF4444", "#44AA44", "#4488FF", "#FF8800", | |
| "#AA44FF", "#00CCCC", "#FF44AA", "#88AA00", | |
| ] | |
def parse_bboxes(text):
    """Extract a list of bounding-box dicts from model output text.

    Looks for a JSON array first inside a ```json fenced block, then as a
    bare ``[{...}]`` span anywhere in the text.  Each entry is normalized so
    the box label lives under the "label" key ("bbox_label" is accepted as
    an alias).

    Returns the list of dicts, or None when no parseable bbox JSON is found.
    """
    # Prefer an explicit code fence; fall back to the first bare JSON array
    # of objects anywhere in the output.
    match = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', text, re.DOTALL)
    if not match:
        match = re.search(r'(\[\s*\{.*?\}\s*\])', text, re.DOTALL)
    if not match:
        return None
    try:
        bboxes = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    # The first element must be a dict carrying "bbox_2d"; previously a plain
    # list such as [1, 2, 3] raised an uncaught TypeError on the `in` test.
    if (
        isinstance(bboxes, list)
        and bboxes
        and isinstance(bboxes[0], dict)
        and "bbox_2d" in bboxes[0]
    ):
        # Normalize: accept both "label" and "bbox_label"; skip any non-dict
        # stragglers instead of crashing on .pop().
        for b in bboxes:
            if isinstance(b, dict) and "label" not in b and "bbox_label" in b:
                b["label"] = b.pop("bbox_label")
        return bboxes
    return None
def draw_bboxes(image, bboxes):
    """Return a copy of *image* with labeled defect boxes drawn on it.

    Parameters
    ----------
    image : PIL.Image.Image
        Source image; it is copied, never modified in place.
    bboxes : list[dict]
        Entries shaped like ``{"bbox_2d": [x1, y1, x2, y2], "label": str}``;
        entries whose "bbox_2d" is not a 4-element sequence are skipped.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)
    # The original also loaded a 16pt bold face, but only the 13pt label
    # font was ever used for drawing, so only that one is loaded now.
    try:
        label_font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 13
        )
    except (IOError, OSError):
        label_font = ImageFont.load_default()
    for i, bbox_info in enumerate(bboxes):
        bbox = bbox_info.get("bbox_2d", [])
        label = bbox_info.get("label", f"defect_{i}")
        if len(bbox) != 4:
            continue
        # Normalize corner order: Pillow raises ValueError when x2 < x1 or
        # y2 < y1, and the model occasionally emits reversed corners.
        x1, x2 = sorted((bbox[0], bbox[2]))
        y1, y2 = sorted((bbox[1], bbox[3]))
        color = COLORS[i % len(COLORS)]
        # One 3-px border drawn inward from the expanded rect covers the same
        # pixels as the former three nested 1-px rectangles.
        draw.rectangle([x1 - 2, y1 - 2, x2 + 2, y2 + 2], outline=color, width=3)
        # Filled label tag above the box (clamped so it stays on-canvas).
        text_bbox = draw.textbbox((0, 0), label, font=label_font)
        tw = text_bbox[2] - text_bbox[0] + 8
        th = text_bbox[3] - text_bbox[1] + 6
        label_y = max(0, y1 - th - 2)
        draw.rectangle([x1, label_y, x1 + tw, label_y + th], fill=color)
        draw.text((x1 + 4, label_y + 2), label, fill="white", font=label_font)
    return img
| # --------------------------------------------------------------------------- | |
| # Inference (supports both single-image and paired-image modes) | |
| # --------------------------------------------------------------------------- | |
def _run_inference(reference_image, test_image, prompt, max_new_tokens, _t_enter=None):
    """Core inference logic shared by both predict functions.

    Args:
        reference_image: optional PIL image of a normal ("good") sample.
        test_image: optional PIL image to inspect.
        prompt: user text appended after the image(s) in the chat message.
        max_new_tokens: generation budget (coerced to int below).
        _t_enter: wall-clock timestamp taken by the wrapper before GPU
            allocation; used only to report GPU-load latency.

    Returns:
        (output_text, vis_image) where vis_image is a PIL image with drawn
        boxes, or None when no bbox JSON was found in the output.
    """
    t_gpu_ready = time.time()  # GPU is allocated by this point
    has_ref = reference_image is not None
    has_test = test_image is not None
    if not has_ref and not has_test:
        return "Please upload at least one image.", None
    try:
        t0 = time.time()
        max_new_tokens = int(max_new_tokens)
        content = []
        # Build the multimodal message: reference first (if any), then test
        # image, then the text prompt.  Images are downscaled in place to at
        # most 512x512 (aspect preserved) to bound vision-token count.
        if has_ref and has_test:
            ref = reference_image.copy()
            tst = test_image.copy()
            ref.thumbnail((512, 512), Image.Resampling.LANCZOS)
            tst.thumbnail((512, 512), Image.Resampling.LANCZOS)
            content.append({"type": "image", "image": ref})
            content.append({"type": "image", "image": tst})
            vis_source = tst
        elif has_test:
            tst = test_image.copy()
            tst.thumbnail((512, 512), Image.Resampling.LANCZOS)
            content.append({"type": "image", "image": tst})
            vis_source = tst
        else:
            # Reference only: treat it as the image under inspection.
            ref = reference_image.copy()
            ref.thumbnail((512, 512), Image.Resampling.LANCZOS)
            content.append({"type": "image", "image": ref})
            vis_source = ref
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(model.device)
        t_preprocess = time.time()
        # Greedy decoding for reproducible anomaly verdicts.
        generated_ids = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False
        )
        t_generate = time.time()
        # Strip the echoed prompt tokens from each sequence before decoding.
        generated_ids_trimmed = [
            out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)
        ]
        n_tokens = generated_ids.shape[1] - inputs.input_ids.shape[1]
        output = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        # NOTE(review): boxes are drawn on the 512px thumbnail; assumes the
        # model emits coordinates in that same space — confirm the processor's
        # resize does not change the coordinate frame.
        bboxes = parse_bboxes(output)
        vis_image = None
        if bboxes:
            vis_image = draw_bboxes(vis_source, bboxes)
        # Timing summary appended below the model output.
        gpu_load_t = t_gpu_ready - _t_enter if _t_enter else 0
        prep_t = t_preprocess - t0
        gen_t = t_generate - t_preprocess
        tps = n_tokens / gen_t if gen_t > 0 else 0
        parts = []
        if gpu_load_t > 0.5:  # only worth reporting when allocation was slow
            parts.append(f"GPU Load: {gpu_load_t:.1f}s")
        parts.append(f"Preprocess: {prep_t:.1f}s")
        parts.append(f"Generate: {gen_t:.1f}s ({n_tokens} tokens, {tps:.1f} tok/s)")
        output += f"\n\n---\n[{_attn_impl}] {' | '.join(parts)}"
        return output, vis_image
    except Exception as e:
        # Demo-boundary handler: surface the traceback in the UI textbox
        # instead of crashing the Gradio worker.
        tb = traceback.format_exc()
        print(tb, flush=True)
        return f"Error:\n{tb}", None
| def _wrap_predict(fn): | |
| def wrapper(reference_image, test_image, prompt, max_new_tokens): | |
| t_enter = time.time() | |
| return fn(reference_image, test_image, prompt, max_new_tokens, _t_enter=t_enter) | |
| return wrapper | |
# NOTE(review): `spaces` is imported, _run_inference measures "GPU Load"
# time, and these two otherwise-identical functions only make sense with
# different @spaces.GPU quotas — the decorators appear to have been lost.
# Durations mirror the "Run (5s)" / "Run Long (10s)" button labels; confirm.
@spaces.GPU(duration=5)
def _predict_gpu(reference_image, test_image, prompt, max_new_tokens, _t_enter=None):
    """Short-quota ZeroGPU entry point; delegates to _run_inference."""
    return _run_inference(reference_image, test_image, prompt, max_new_tokens, _t_enter=_t_enter)


@spaces.GPU(duration=10)
def _predict_long_gpu(reference_image, test_image, prompt, max_new_tokens, _t_enter=None):
    """Longer-quota ZeroGPU entry point for larger max_new_tokens runs."""
    return _run_inference(reference_image, test_image, prompt, max_new_tokens, _t_enter=_t_enter)


# Public callables wired to the Gradio buttons; the wrapper stamps the call
# entry time so _run_inference can report GPU-allocation latency.
predict = _wrap_predict(_predict_gpu)
predict_long = _wrap_predict(_predict_long_gpu)
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
TITLE = "AD-Copilot: Comparison-Aware Anomaly Detection"
DESCRIPTION = """
**AD-Copilot** extends Qwen2.5-VL with a novel **comparison-aware visual encoder** that generates
special comparison tokens capturing differences between a reference image and a test image,
achieving state-of-the-art results on industrial anomaly detection benchmarks.
**Two modes:** Upload both images for comparison-based inspection, or just a test image for single-image tasks (counting, OCR, etc.).
[[Paper]](https://arxiv.org/abs/2603.13779v1) | [[Code]](https://github.com/jam-cc/AD-Copilot) | [[Model]](https://huggingface.co/jiang-cc/AD-Copilot)
"""
# Example rows for gr.Examples:
# [reference_image_path, test_image_path, prompt, max_new_tokens]
EXAMPLES = [
    # 1. Anomaly Discrimination (yes/no) — bottle contamination
    [
        "examples/bottle_good.jpg",
        "examples/bottle_contamination.jpg",
        "The first image is a normal sample. Is there any anomaly in the second image? A. Yes B. No. Please answer the letter only.",
        128,
    ],
    # 2. Defect Description — cable cut
    [
        "examples/cable_good.jpg",
        "examples/cable_cut.jpg",
        "The first image is a normal sample. Compared with the first image, please describe the anomaly in the second image in detail.",
        256,
    ],
    # 3. Fine-grained Defect Localization — bottle contamination
    [
        "examples/bottle_good.jpg",
        "examples/bottle_contamination.jpg",
        "The first image is a normal sample. Please locate the defects within the second image with bounding box in JSON format.",
        256,
    ],
    # 4. Defect Localization — PCB wrong position
    [
        "examples/pcb_good.jpg",
        "examples/pcb_defect.jpg",
        "The first image is a normal sample. Please locate the defects within the second image with bounding box in JSON format.",
        256,
    ],
    # 5. Object Counting (single image) — candle
    [
        None,
        "examples/candle_count.jpg",
        "How many candles are in this image? For each candle, give its bounding box as {\"bbox_2d\": [x1,y1,x2,y2], \"label\": \"candle_1\"}, numbering them 1, 2, 3, etc. State the total count at the end.",
        512,
    ],
    # 6. Industrial OCR (single image) — drink bottle
    [
        None,
        "examples/drink_bottle_ocr.jpg",
        "Please read all text and labels on this product.",
        256,
    ],
]
# Gradio layout: two image inputs side by side, a prompt box, token slider,
# two run buttons (short/long GPU quota), and text + visualization outputs.
# Statement order here defines the on-page component order.
with gr.Blocks(theme=gr.themes.Soft(), title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            ref_img = gr.Image(
                label="Reference (Good) Image (optional)",
                type="pil",
                height=300,
            )
        with gr.Column(scale=1):
            test_img = gr.Image(
                label="Test Image",
                type="pil",
                height=300,
            )
    prompt = gr.Textbox(
        label="Prompt",
        value="The first image is a normal sample. Is there any anomaly in the second image? A. Yes B. No. Please answer the letter only.",
        lines=2,
    )
    with gr.Row():
        max_tokens = gr.Slider(
            minimum=16,
            maximum=1024,
            value=128,
            step=16,
            label="Max New Tokens",
        )
        run_btn = gr.Button("Run (5s)", variant="primary", scale=2)
        run_long_btn = gr.Button("Run Long (10s)", variant="secondary", scale=1)
    output = gr.Textbox(label="Model Output", lines=4)
    vis_output = gr.Image(label="Detection Visualization")
    # Both buttons share inputs/outputs and differ only in the backing
    # predict function (short vs. long GPU budget).
    run_btn.click(
        fn=predict,
        inputs=[ref_img, test_img, prompt, max_tokens],
        outputs=[output, vis_output],
    )
    run_long_btn.click(
        fn=predict_long,
        inputs=[ref_img, test_img, prompt, max_tokens],
        outputs=[output, vis_output],
    )
    # cache_examples=False: examples run live (they need the GPU model).
    gr.Examples(
        examples=EXAMPLES,
        inputs=[ref_img, test_img, prompt, max_tokens],
        outputs=[output, vis_output],
        fn=predict,
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()