# AD-Copilot / app.py
# Author: jiang-cc — commit 6e9b876 (verified)
# fix: PCB use pcb3/bad009 (Wrong Position), candle prompt with numbered
# labels and total count
"""
AD-Copilot Demo: Comparison-Aware Anomaly Detection with Vision-Language Model
"""
import json
import os
import re
import time
import traceback
import spaces
import gradio as gr
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image, ImageDraw, ImageFont
# ---------------------------------------------------------------------------
# Model loading (happens once at Space startup; weights stay on CPU until
# @spaces.GPU moves them to GPU on demand)
# ---------------------------------------------------------------------------
MODEL_ID = "jiang-cc/AD-Copilot"
# min/max_pixels bound the Qwen-VL dynamic-resolution tokenizer
# (multiples of the 28x28 vision patch size).
processor = AutoProcessor.from_pretrained(
MODEL_ID,
min_pixels=64 * 28 * 28,
max_pixels=1280 * 28 * 28,
trust_remote_code=True,
)
# Prefer FlashAttention-2 when its wheel is installed; otherwise fall back
# to PyTorch scaled-dot-product attention.
try:
    import flash_attn  # noqa: F401
    _attn_impl = "flash_attention_2"
except ImportError:
    _attn_impl = "sdpa"
model = AutoModelForImageTextToText.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
attn_implementation=_attn_impl,
trust_remote_code=True,
).to("cuda").eval()
print(f"[AD-Copilot] Attention: {_attn_impl} | Device: {model.device}", flush=True)
# ---------------------------------------------------------------------------
# BBox visualization
# ---------------------------------------------------------------------------
# Outline colors for drawn boxes; cycled when there are more boxes than colors.
COLORS = [
"#FF4444", "#44AA44", "#4488FF", "#FF8800",
"#AA44FF", "#00CCCC", "#FF44AA", "#88AA00",
]
def parse_bboxes(text):
    """Extract a list of bounding-box dicts from model output text.

    Looks for a JSON array either inside a ```json fenced code block or as a
    bare ``[{...}]`` span. Each entry must be a dict; the first entry must
    carry a ``"bbox_2d"`` key. ``"bbox_label"`` keys are renamed to
    ``"label"`` in place.

    Returns the list of dicts, or ``None`` when no well-formed box list is
    found.
    """
    # Prefer a fenced code block; fall back to the first bare [{...}] span.
    match = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', text, re.DOTALL)
    if not match:
        match = re.search(r'(\[\s*\{.*?\}\s*\])', text, re.DOTALL)
    if not match:
        return None
    try:
        bboxes = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    # Validate shape: a non-empty list of dicts whose first entry has
    # "bbox_2d". (The original membership test raised an uncaught TypeError
    # when the array held non-dict elements such as bare ints.)
    if (
        not isinstance(bboxes, list)
        or not bboxes
        or not all(isinstance(b, dict) for b in bboxes)
        or "bbox_2d" not in bboxes[0]
    ):
        return None
    # Normalize: accept both "label" and "bbox_label".
    for b in bboxes:
        if "label" not in b and "bbox_label" in b:
            b["label"] = b.pop("bbox_label")
    return bboxes
def draw_bboxes(image, bboxes):
    """Return a copy of *image* with a labelled rectangle per detected defect.

    Parameters
    ----------
    image : PIL.Image.Image — base image; left unmodified (a copy is drawn on).
    bboxes : list[dict] — entries as produced by parse_bboxes(); each needs a
        4-element numeric "bbox_2d" [x1, y1, x2, y2] and may carry a "label".

    Entries with a malformed or non-numeric "bbox_2d" are skipped.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)
    try:
        small_font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 13
        )
    except (IOError, OSError):
        # Font file missing (non-Debian base image): fall back to PIL builtin.
        small_font = ImageFont.load_default()
    for i, bbox_info in enumerate(bboxes):
        bbox = bbox_info.get("bbox_2d", [])
        label = str(bbox_info.get("label", f"defect_{i}"))
        if len(bbox) != 4:
            continue
        try:
            x1, y1, x2, y2 = (float(v) for v in bbox)
        except (TypeError, ValueError):
            continue  # model emitted non-numeric coordinates
        # PIL's rectangle() requires x1 <= x2 and y1 <= y2; model output can
        # be unordered, which previously raised ValueError mid-draw.
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        color = COLORS[i % len(COLORS)]
        # Three nested outlines approximate a 3px-thick border.
        for w in range(3):
            draw.rectangle([x1 - w, y1 - w, x2 + w, y2 + w], outline=color)
        # Filled label tag above the box, clamped to the top edge.
        text_bbox = draw.textbbox((0, 0), label, font=small_font)
        tw = text_bbox[2] - text_bbox[0] + 8
        th = text_bbox[3] - text_bbox[1] + 6
        label_y = max(0, y1 - th - 2)
        draw.rectangle([x1, label_y, x1 + tw, label_y + th], fill=color)
        draw.text((x1 + 4, label_y + 2), label, fill="white", font=small_font)
    return img
# ---------------------------------------------------------------------------
# Inference (supports both single-image and paired-image modes)
# ---------------------------------------------------------------------------
def _prepare_image(img):
    """Return a copy of *img* downscaled to fit 512x512 (aspect preserved)."""
    out = img.copy()
    out.thumbnail((512, 512), Image.Resampling.LANCZOS)
    return out


def _run_inference(reference_image, test_image, prompt, max_new_tokens, _t_enter=None):
    """Core inference logic shared by the 5s and 10s GPU entry points.

    Parameters
    ----------
    reference_image : PIL.Image.Image | None — optional "good" sample; when
        given, it is sent as the first image of the pair.
    test_image : PIL.Image.Image | None — image to inspect.
    prompt : str — user instruction appended after the image(s).
    max_new_tokens : int-like — generation budget (coerced with int()).
    _t_enter : float | None — wall-clock timestamp captured before GPU
        allocation; used to report queue/GPU-load latency in the footer.

    Returns
    -------
    (str, PIL.Image.Image | None) — model text with a timing footer, and an
    annotated image when bounding boxes could be parsed from the output.
    """
    t_gpu_ready = time.time()  # GPU is allocated by this point
    if reference_image is None and test_image is None:
        return "Please upload at least one image.", None
    try:
        t0 = time.time()
        max_new_tokens = int(max_new_tokens)
        # Reference (if any) first, then test image — prompt wording
        # ("first image is a normal sample") relies on this order.
        images = [
            _prepare_image(img)
            for img in (reference_image, test_image)
            if img is not None
        ]
        # Draw detections on the test image when present, else the reference.
        vis_source = images[-1]
        content = [{"type": "image", "image": im} for im in images]
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(model.device)
        t_preprocess = time.time()
        generated_ids = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False
        )
        t_generate = time.time()
        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)
        ]
        n_tokens = generated_ids.shape[1] - inputs.input_ids.shape[1]
        output = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        bboxes = parse_bboxes(output)
        vis_image = draw_bboxes(vis_source, bboxes) if bboxes else None
        # Timing footer. Compare against None explicitly: the original
        # truthiness check would (theoretically) drop a 0.0 timestamp.
        gpu_load_t = t_gpu_ready - _t_enter if _t_enter is not None else 0.0
        prep_t = t_preprocess - t0
        gen_t = t_generate - t_preprocess
        tps = n_tokens / gen_t if gen_t > 0 else 0
        parts = []
        if gpu_load_t > 0.5:
            parts.append(f"GPU Load: {gpu_load_t:.1f}s")
        parts.append(f"Preprocess: {prep_t:.1f}s")
        parts.append(f"Generate: {gen_t:.1f}s ({n_tokens} tokens, {tps:.1f} tok/s)")
        output += f"\n\n---\n[{_attn_impl}] {' | '.join(parts)}"
        return output, vis_image
    except Exception:
        # Surface the full traceback in the UI (demo app; easier debugging
        # than a terse message) and log it to stdout for the Space logs.
        tb = traceback.format_exc()
        print(tb, flush=True)
        return f"Error:\n{tb}", None
def _wrap_predict(fn):
def wrapper(reference_image, test_image, prompt, max_new_tokens):
t_enter = time.time()
return fn(reference_image, test_image, prompt, max_new_tokens, _t_enter=t_enter)
return wrapper
# GPU entry points: identical logic, different ZeroGPU time budgets — the
# short one for yes/no and description prompts, the long one for
# localization/counting prompts with larger token budgets.
@spaces.GPU(duration=5)
def _predict_gpu(reference_image, test_image, prompt, max_new_tokens, _t_enter=None):
    return _run_inference(reference_image, test_image, prompt, max_new_tokens, _t_enter=_t_enter)
@spaces.GPU(duration=10)
def _predict_long_gpu(reference_image, test_image, prompt, max_new_tokens, _t_enter=None):
    return _run_inference(reference_image, test_image, prompt, max_new_tokens, _t_enter=_t_enter)
# Public callables wired to the Gradio buttons; the wrapper captures the
# click timestamp so _run_inference can report GPU-allocation latency.
predict = _wrap_predict(_predict_gpu)
predict_long = _wrap_predict(_predict_long_gpu)
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# UI copy and canned demo inputs. Each EXAMPLES row matches the Gradio
# inputs: [reference_image_path, test_image_path, prompt, max_new_tokens].
TITLE = "AD-Copilot: Comparison-Aware Anomaly Detection"
# NOTE(review): the arXiv id 2603.13779 implies a 2026-03 submission —
# confirm the link is correct.
DESCRIPTION = """
**AD-Copilot** extends Qwen2.5-VL with a novel **comparison-aware visual encoder** that generates
special comparison tokens capturing differences between a reference image and a test image,
achieving state-of-the-art results on industrial anomaly detection benchmarks.
**Two modes:** Upload both images for comparison-based inspection, or just a test image for single-image tasks (counting, OCR, etc.).
[[Paper]](https://arxiv.org/abs/2603.13779v1) | [[Code]](https://github.com/jam-cc/AD-Copilot) | [[Model]](https://huggingface.co/jiang-cc/AD-Copilot)
"""
EXAMPLES = [
    # 1. Anomaly Discrimination (yes/no) — bottle contamination
    [
        "examples/bottle_good.jpg",
        "examples/bottle_contamination.jpg",
        "The first image is a normal sample. Is there any anomaly in the second image? A. Yes B. No. Please answer the letter only.",
        128,
    ],
    # 2. Defect Description — cable cut
    [
        "examples/cable_good.jpg",
        "examples/cable_cut.jpg",
        "The first image is a normal sample. Compared with the first image, please describe the anomaly in the second image in detail.",
        256,
    ],
    # 3. Fine-grained Defect Localization — bottle contamination
    [
        "examples/bottle_good.jpg",
        "examples/bottle_contamination.jpg",
        "The first image is a normal sample. Please locate the defects within the second image with bounding box in JSON format.",
        256,
    ],
    # 4. Defect Localization — PCB wrong position
    [
        "examples/pcb_good.jpg",
        "examples/pcb_defect.jpg",
        "The first image is a normal sample. Please locate the defects within the second image with bounding box in JSON format.",
        256,
    ],
    # 5. Object Counting (single image) — candle; no reference image,
    # so _run_inference falls back to single-image mode.
    [
        None,
        "examples/candle_count.jpg",
        "How many candles are in this image? For each candle, give its bounding box as {\"bbox_2d\": [x1,y1,x2,y2], \"label\": \"candle_1\"}, numbering them 1, 2, 3, etc. State the total count at the end.",
        512,
    ],
    # 6. Industrial OCR (single image) — drink bottle
    [
        None,
        "examples/drink_bottle_ocr.jpg",
        "Please read all text and labels on this product.",
        256,
    ],
]
# Gradio layout. Component creation order inside the context managers defines
# the on-screen layout, so the statement order below is load-bearing.
with gr.Blocks(theme=gr.themes.Soft(), title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    # Side-by-side uploaders: optional reference on the left, test on the right.
    with gr.Row():
        with gr.Column(scale=1):
            ref_img = gr.Image(
                label="Reference (Good) Image (optional)",
                type="pil",
                height=300,
            )
        with gr.Column(scale=1):
            test_img = gr.Image(
                label="Test Image",
                type="pil",
                height=300,
            )
    prompt = gr.Textbox(
        label="Prompt",
        value="The first image is a normal sample. Is there any anomaly in the second image? A. Yes B. No. Please answer the letter only.",
        lines=2,
    )
    with gr.Row():
        max_tokens = gr.Slider(
            minimum=16,
            maximum=1024,
            value=128,
            step=16,
            label="Max New Tokens",
        )
        # Two buttons map to the two @spaces.GPU budgets (5s vs 10s).
        run_btn = gr.Button("Run (5s)", variant="primary", scale=2)
        run_long_btn = gr.Button("Run Long (10s)", variant="secondary", scale=1)
    output = gr.Textbox(label="Model Output", lines=4)
    # Shows the bbox overlay; stays empty when no boxes were parsed.
    vis_output = gr.Image(label="Detection Visualization")
    run_btn.click(
        fn=predict,
        inputs=[ref_img, test_img, prompt, max_tokens],
        outputs=[output, vis_output],
    )
    run_long_btn.click(
        fn=predict_long,
        inputs=[ref_img, test_img, prompt, max_tokens],
        outputs=[output, vis_output],
    )
    gr.Examples(
        examples=EXAMPLES,
        inputs=[ref_img, test_img, prompt, max_tokens],
        outputs=[output, vis_output],
        fn=predict,
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()