| | import gradio as gr |
| | from torch.nn.functional import softmax |
| | import torch |
| | from transformers import ViTFeatureExtractor |
| | from transformers import MobileViTFeatureExtractor |
| | from transformers import MobileViTForImageClassification |
| | from transformers import ViTForImageClassification |
| |
|
| |
|
# Local checkpoint directory for each model choice offered in the UI dropdown.
_MODEL_DIRS = {
    "ViT": './models/vit-base-garbage/',
    "MobileViT": './models/apple/mobilevit-small-garbage/',
}

# Loaded (feature_extractor, model) pairs keyed by model_type, so each
# checkpoint is read from disk once per process instead of on every request.
_MODEL_CACHE = {}


def _load_model(model_type):
    """Return a cached ``(feature_extractor, model)`` pair for *model_type*.

    Raises:
        ValueError: if *model_type* is not a key of ``_MODEL_DIRS``.
    """
    if model_type not in _MODEL_DIRS:
        raise ValueError(
            f"Unknown model_type {model_type!r}; "
            f"expected one of {sorted(_MODEL_DIRS)}"
        )
    if model_type not in _MODEL_CACHE:
        path = _MODEL_DIRS[model_type]
        if model_type == "ViT":
            feature_extractor = ViTFeatureExtractor.from_pretrained(path)
            model = ViTForImageClassification.from_pretrained(path)
        else:  # "MobileViT" -- the only other key of _MODEL_DIRS
            feature_extractor = MobileViTFeatureExtractor.from_pretrained(path)
            model = MobileViTForImageClassification.from_pretrained(path)
        _MODEL_CACHE[model_type] = (feature_extractor, model)
    return _MODEL_CACHE[model_type]


def _logits_to_confidences(logits, id2label):
    """Turn the first row of a ``(batch, num_labels)`` logits tensor into a
    ``{label: probability}`` dict covering every class the model knows."""
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    # Use id2label (class index -> name) rather than the key order of
    # label2id, which is not guaranteed to follow the class indices.
    return {id2label[i]: p for i, p in enumerate(probabilities.tolist())}


def predict(model_type, inp):
    """Classify image *inp* with the selected model.

    Args:
        model_type: ``"ViT"`` or ``"MobileViT"`` (the dropdown choices).
        inp: the image to classify, as handed over by the Gradio image
            input (configured with ``type="pil"`` in this app).

    Returns:
        dict mapping each class label to its softmax probability, the format
        the Gradio ``Label`` output consumes.

    Raises:
        ValueError: if *model_type* is not a supported model name.
    """
    feature_extractor, model = _load_model(model_type)
    inputs = feature_extractor(inp, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # outputs[0] is the logits tensor; a single image is classified per call.
    return _logits_to_confidences(outputs[0], model.config.id2label)
| |
|
| |
|
# Gradio UI: choose a model from the dropdown, supply an image, and display
# the three highest-confidence labels returned by ``predict``.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(["ViT", "MobileViT"], label="Model Name", value='ViT'),
        # type="pil" makes Gradio pass predict() a PIL image.
        gr.Image(type="pil"),
    ],
    outputs=gr.Label(num_top_classes=3),
    # Each example row is [model_type, image path]; the paths are relative,
    # so the images presumably live in the working directory -- confirm.
    examples=[
        ["ViT", "paper567.jpg"],
        ["ViT", "trash105.jpg"],
        ["ViT", "plastic202.jpg"],
        ["MobileViT", "metal382.jpg"],
    ],
)

# Only start the server when run as a script, so importing this module
# (e.g. for tests or Gradio's reload mode) does not block on launch().
if __name__ == "__main__":
    demo.launch()