File size: 3,503 Bytes
dd33601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from __future__ import annotations

import argparse
from pathlib import Path
from typing import Iterable, List

import numpy as np
import torch
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torchvision import transforms

from data.dct import DCT_base_Rec_Module
from models import AIDE as build_aide_model


# Side length, in pixels, of the square tensor every view is resized to.
IMAGE_SIZE = 256

# Standard ImageNet per-channel (R, G, B) statistics, as used by torchvision's
# pretrained backbones.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> float CHW tensor (torchvision scales 8-bit pixels to [0, 1]).
TO_TENSOR = transforms.ToTensor()

# Tensor preprocessing shared by every model view. Note: despite the name, the
# resize is applied first and the mean/std normalization second.
NORMALIZE_AND_RESIZE = transforms.Compose(
    [
        transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)


def build_aide_input_from_pil(image: Image.Image, dct_module: DCT_base_Rec_Module) -> torch.Tensor:
    """Build the five-view AIDE input tensor for a single PIL image.

    The image is forced to RGB and converted with ``TO_TENSOR``. ``dct_module``
    is applied to that tensor and must return exactly four outputs. Each of
    those four outputs, plus the original tensor, is passed through
    ``NORMALIZE_AND_RESIZE``, and the results are stacked along a new leading
    dimension.

    Args:
        image: Input image in any PIL mode.
        dct_module: DCT reconstruction module producing the four auxiliary views.

    Returns:
        Tensor whose first dimension has size 5, ordered
        ``(x_minmin, x_maxmax, x_minmin1, x_maxmax1, original)``.
    """
    rgb_tensor = TO_TENSOR(image.convert("RGB"))
    dct_minmin, dct_maxmax, dct_minmin1, dct_maxmax1 = dct_module(rgb_tensor)

    # Order matters: presumably it must match the view order used in training,
    # with the four DCT views first and the plain image last. Keep in sync.
    views = (dct_minmin, dct_maxmax, dct_minmin1, dct_maxmax1, rgb_tensor)
    return torch.stack([NORMALIZE_AND_RESIZE(view) for view in views], dim=0)


def load_model(
    repo_dir: str | Path,
    device: str | None = None,
    weights_name: str = "model.safetensors",
) -> torch.nn.Module:
    repo_dir = Path(repo_dir)
    weights_path = repo_dir / weights_name
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    model = build_aide_model(resnet_path=None, convnext_path=None)
    state_dict = load_safetensors(str(weights_path))
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
    return model


@torch.inference_mode()
def predict_pil_images(
    model: torch.nn.Module,
    images: Iterable[Image.Image],
    device: str | None = None,
) -> List[dict]:
    """Classify PIL images as real or fake with an AIDE model.

    All images are preprocessed, stacked into one batch, and run through
    ``model`` in a single forward pass. A softmax is then taken over the last
    dimension of the logits. Index 0 is read as the "real" probability and
    index 1 as the "fake" probability.

    Args:
        model: Loaded AIDE model.
        images: Images to classify. Consumed once.
        device: Device to run the batch on. Defaults to the device of the
            model's first parameter, including its index (e.g. ``cuda:1``).

    Returns:
        One dict per input image, in input order, with keys ``label``
        (``"fake"`` when the fake probability is >= the real one, else
        ``"real"``), ``real_probability``, and ``fake_probability`` (the last
        two rounded to 6 decimals). An empty input yields an empty list.
    """
    image_list = list(images)
    if not image_list:
        # torch.stack rejects an empty sequence; there is nothing to classify.
        return []

    # Use the full torch.device, not `.device.type`: the type alone ("cuda")
    # drops the index and would send the batch to cuda:0 even when the model
    # lives on another GPU.
    device = device or next(model.parameters()).device
    dct_module = DCT_base_Rec_Module()
    batch = torch.stack(
        [build_aide_input_from_pil(img, dct_module) for img in image_list],
        dim=0,
    ).to(device)
    logits = model(batch)
    probs = torch.softmax(logits, dim=-1).cpu().numpy()

    outputs = []
    for prob in probs:
        real_prob = float(prob[0])
        fake_prob = float(prob[1])
        label = "fake" if fake_prob >= real_prob else "real"
        outputs.append(
            {
                "label": label,
                "real_probability": round(real_prob, 6),
                "fake_probability": round(fake_prob, 6),
            }
        )
    return outputs


def _load_images(paths: Iterable[str]) -> List[Image.Image]:
    return [Image.open(path).convert("RGB") for path in paths]


def main() -> None:
    """Command-line entry point: classify the given image files and print results.

    Loads the model from ``--repo_dir``, runs every ``--image`` through it in
    one batch, and prints one dict per image. Each dict holds the image path
    followed by the fields returned by ``predict_pil_images``.
    """
    cli = argparse.ArgumentParser(description="Run AIDE image detector inference.")
    cli.add_argument("--repo_dir", type=str, default=".", help="Local path to the model repository.")
    cli.add_argument("--image", type=str, nargs="+", required=True, help="One or more image paths.")
    cli.add_argument("--device", type=str, default=None, help="cuda or cpu")
    opts = cli.parse_args()

    detector = load_model(opts.repo_dir, device=opts.device)
    results = predict_pil_images(detector, _load_images(opts.image), device=opts.device)

    for path, result in zip(opts.image, results):
        # "image" goes in first so it leads the printed dict.
        record = {"image": str(path)}
        record.update(result)
        print(record)


if __name__ == "__main__":
    main()