# RetinaSense-ViT v4: Hybrid Vision Transformer for Retinal Disease Classification

RetinaSense-ViT v4 is a hybrid deep learning model that fuses ViT-Base/16 and EfficientNet-B3 feature representations for automated retinal disease screening from color fundus photographs. It classifies images into 5 disease categories with calibrated confidence estimates, uncertainty quantification, and explainable attention maps.

HuggingFace Repository: [tanishq74/retinasense-vit](https://huggingface.co/tanishq74/retinasense-vit)
## Model Details
| Property | Value |
|---|---|
| Architecture | HybridRetinaModel (ViT-Base/16 + EfficientNet-B3 fusion) |
| Parameters | 97.8M |
| Input Size | 224 x 224 x 3 (RGB fundus photograph) |
| Output | 5-class probability distribution |
| Framework | PyTorch |
| License | MIT |
## Architecture Overview

The HybridRetinaModel extracts complementary features from two pretrained backbones and fuses them through concatenation followed by a classification MLP:

```
Input Image (224 x 224 x 3)
    |
    +--> EfficientNet-B3 --> 1536-dim pooled features --+
    |                                                   |
    +--> ViT-Base/16 ------>  768-dim CLS token --------+
                                                        |
                                       Concatenation (2304-dim)
                                                        |
                                                MLP Classifier
                                                        |
                                             5-class predictions
```

- EfficientNet-B3 branch: Extracts local texture and fine-grained pathological features (1536-dim pooled features)
- ViT-Base/16 branch: Captures global spatial relationships and long-range dependencies (768-dim CLS token)
- Fusion layer: Concatenated 2304-dim representation passed through a multi-layer perceptron with batch normalization and dropout
- Transfer learning: Initialized from the v3 pretrained ViT (best_model.pth, 82.59% val. accuracy) and EfficientNet-B3 (efficientnet_b3.pth, 71.1% val. accuracy)
## Disease Classes
| Label | Class | Description |
|---|---|---|
| 0 | Normal | No retinal disease detected |
| 1 | Diabetes/DR | Diabetic retinopathy (microaneurysms, hemorrhages, exudates) |
| 2 | Glaucoma | Optic disc cupping and glaucomatous changes |
| 3 | Cataract | Lens opacity affecting fundus image clarity |
| 4 | AMD | Age-related macular degeneration (drusen, atrophy) |
## Training

### Dataset

The training dataset was constructed by merging two public retinal imaging datasets:
| Stage | Images |
|---|---|
| APTOS-2019 (source) | 3,662 |
| ODIR-5K (source) | 6,392 |
| Merged (raw) | 8,905 |
| After deduplication and quality filtering | 5,270 |
| After class balancing | 10,000 (2,000 per class) |
Balancing was achieved through a combination of undersampling over-represented classes and augmentation-based oversampling of under-represented classes to reach exactly 2,000 images per class.
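The balancing step above can be sketched as a toy helper. This is an illustration of the undersample/oversample logic, not the project's training code; random resampling with replacement stands in for the augmentation-based oversampling described above.

```python
import random

random.seed(0)


def balance_classes(items_by_class, target=2000):
    """Undersample over-represented classes; oversample (with replacement,
    standing in for augmentation) under-represented ones to `target` each."""
    balanced = {}
    for cls, items in items_by_class.items():
        if len(items) >= target:
            balanced[cls] = random.sample(items, target)          # undersample
        else:
            extra = random.choices(items, k=target - len(items))  # oversample
            balanced[cls] = items + extra
    return balanced


data = {"Normal": list(range(3000)), "AMD": list(range(500))}
out = balance_classes(data, target=2000)
print({k: len(v) for k, v in out.items()})  # {'Normal': 2000, 'AMD': 2000}
```

Note that oversampled classes keep every original image, so no minority-class data is discarded.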
### Data Splits
| Split | Samples | Purpose |
|---|---|---|
| Train | 7,038 | Model training |
| Validation | 1,476 | Hyperparameter tuning and early stopping |
| Test | 1,486 | Final held-out evaluation |
All splits are patient-aware to prevent data leakage -- no patient appears in more than one split.
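A minimal patient-aware split can be sketched as follows. The function name and split fractions are illustrative assumptions, but the invariant is the one stated above: the partition is over patient IDs, not images, so no patient can span two splits.

```python
import random


def patient_aware_split(patient_ids, fracs=(0.70, 0.15, 0.15), seed=42):
    """Partition unique patient IDs (not images) into train/val/test,
    guaranteeing no patient appears in more than one split."""
    uniq = sorted(set(patient_ids))
    random.Random(seed).shuffle(uniq)
    n = len(uniq)
    n_train, n_val = int(fracs[0] * n), int(fracs[1] * n)
    train = set(uniq[:n_train])
    val = set(uniq[n_train:n_train + n_val])
    test = set(uniq[n_train + n_val:])
    return train, val, test


# Hypothetical example: 100 patients, 3 fundus images each.
ids = [f"p{i}" for i in range(100) for _ in range(3)]
train, val, test = patient_aware_split(ids)
print(len(train), len(val), len(test))  # 70 15 15
```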
### Preprocessing Pipeline

The preprocessing pipeline must be replicated exactly for correct inference:

```
Raw Image -> Crop Black Borders -> Resize 224x224 -> CLAHE (L-channel) -> Circular Mask -> Normalize
```
| Step | Details |
|---|---|
| Border Crop | Remove dark padding (pixels with brightness < 7) |
| Resize | 224 x 224 with cv2.INTER_AREA interpolation |
| CLAHE | Applied to L-channel in LAB color space (clipLimit=2.0, tileGridSize=8x8) |
| Circular Mask | Zero out pixels outside a centered circle (radius = 0.48 x min dimension) |
| Normalize | mean=[0.4298, 0.2784, 0.1559], std=[0.2857, 0.2065, 0.1465] |
### Training Configuration
| Parameter | Value |
|---|---|
| Optimizer | AdamW with Layer-wise Learning Rate Decay (LLRD) |
| Scheduler | OneCycleLR with cosine annealing |
| Loss Function | Focal Loss (gamma=2.0) with label smoothing (0.1) |
| Batch Size | 32 |
| Learning Rate | 1e-4 |
| Epochs | 40 |
| Augmentation | MixUp, CutMix, albumentations (flips, rotation, color jitter, elastic transforms) |
| Regularization | Dropout, Stochastic Weight Averaging (SWA) |
| Fine-tuning | Lesion attention training with GradCAM-guided attention loss |
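The layer-wise learning rate decay (LLRD) entry above assigns smaller learning rates to earlier transformer blocks. A minimal sketch on a toy stand-in model follows; the decay factor 0.8 and the toy architecture are assumptions for illustration, not the project's actual training configuration.

```python
import torch
import torch.nn as nn


class ToyViT(nn.Module):
    """Stand-in with the structure LLRD relies on: a stack of blocks plus a head."""
    def __init__(self, n_blocks=4, dim=8):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_blocks))
        self.head = nn.Linear(dim, 5)


def llrd_param_groups(model, base_lr=1e-4, decay=0.8):
    """One optimizer param group per block: the deepest block gets base_lr,
    each earlier block gets base_lr * decay**depth_from_top; head gets base_lr."""
    n = len(model.blocks)
    groups = [
        {"params": block.parameters(), "lr": base_lr * decay ** (n - 1 - i)}
        for i, block in enumerate(model.blocks)
    ]
    groups.append({"params": model.head.parameters(), "lr": base_lr})
    return groups


opt = torch.optim.AdamW(llrd_param_groups(ToyViT()), lr=1e-4, weight_decay=1e-2)
print([f"{g['lr']:.2e}" for g in opt.param_groups])
# ['5.12e-05', '6.40e-05', '8.00e-05', '1.00e-04', '1.00e-04']
```

The intuition is that early layers hold generic pretrained features and should move slowly, while later layers adapt to the fundus domain at the full learning rate.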
## Results

### 5-Fold Stratified Cross-Validation (10,000 images)
| Metric | Mean | Std |
|---|---|---|
| Accuracy | 91.13% | +/- 0.55% |
| Macro F1 | 0.910 | +/- 0.006 |
| Macro AUC | 0.986 | +/- 0.001 |
### Held-Out Test Set (1,486 samples)
| Metric | Value |
|---|---|
| Accuracy | 80.9% (82.0% with optimized thresholds) |
| Macro F1 | 0.813 (0.822 with optimized thresholds) |
| Macro AUC | 0.969 |
| Cohen's Kappa | 0.761 |
| MC Dropout Accuracy @ 90% retention | 86.0% |
### Per-Class Test Set Performance
| Class | F1 | AUC | Precision | Recall |
|---|---|---|---|---|
| Normal | 0.69 | 0.926 | 0.573 | 0.857 |
| Diabetes/DR | 0.78 | 0.965 | 0.844 | 0.726 |
| Glaucoma | 0.78 | 0.981 | 0.925 | 0.670 |
| Cataract | 0.95 | 0.997 | 0.940 | 0.966 |
| AMD | 0.87 | 0.977 | 0.917 | 0.827 |
### Improvement over v3

| Metric | v3 (Ensemble) | v4 (Hybrid Fusion) | Delta |
|---|---|---|---|
| Accuracy | 74.7% | 82.0% | +7.3 pp |
| Macro F1 | 0.712 | 0.822 | +0.110 |
| Macro AUC | 0.951 | 0.969 | +0.018 |
## Usage

### Installation

```bash
pip install torch torchvision timm opencv-python-headless numpy huggingface_hub
```

### Download Model Weights

```python
from huggingface_hub import hf_hub_download
import shutil, os

os.makedirs("weights", exist_ok=True)
for fname in ["best_model.pth", "temperature.json", "thresholds.json"]:
    path = hf_hub_download(repo_id="tanishq74/retinasense-vit", filename=fname)
    shutil.copy(path, f"weights/{fname}")
```
### Inference Example

```python
import json

import cv2
import numpy as np
import timm
import torch
import torch.nn as nn
from torchvision import transforms


# --- Model Definition ---
class HybridRetinaModel(nn.Module):
    def __init__(self, num_classes=5, drop_rate=0.3):
        super().__init__()
        # EfficientNet-B3 branch (1536-dim pooled features)
        self.efficientnet = timm.create_model(
            'efficientnet_b3', pretrained=False, num_classes=0
        )
        # ViT-Base/16 branch (768-dim CLS token)
        self.vit = timm.create_model(
            'vit_base_patch16_224', pretrained=False, num_classes=0
        )
        # Fusion MLP: 1536 + 768 = 2304 -> num_classes
        fusion_dim = 2304
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(drop_rate * 0.67),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        eff_features = self.efficientnet(x)                     # (B, 1536)
        vit_features = self.vit(x)                              # (B, 768)
        fused = torch.cat([eff_features, vit_features], dim=1)  # (B, 2304)
        return self.classifier(fused)


# --- Preprocessing ---
MEAN = [0.4298, 0.2784, 0.1559]
STD = [0.2857, 0.2065, 0.1465]
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']

normalize = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])


def preprocess(img_path):
    img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    # Crop black borders (pixels with brightness < 7)
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    mask = gray > 7
    rows, cols = np.any(mask, axis=1), np.any(mask, axis=0)
    if rows.any() and cols.any():
        r0, r1 = np.where(rows)[0][[0, -1]]
        c0, c1 = np.where(cols)[0][[0, -1]]
        img = img[r0:r1 + 1, c0:c1 + 1]
    # Resize
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
    # CLAHE on the L-channel in LAB color space
    lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
    lab[:, :, 0] = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(lab[:, :, 0])
    img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    # Circular mask (radius = 0.48 x min dimension)
    h, w = img.shape[:2]
    cmask = np.zeros((h, w), dtype=np.uint8)
    cv2.circle(cmask, (w // 2, h // 2), int(min(h, w) * 0.48), 255, -1)
    img = cv2.bitwise_and(img, img, mask=cmask)
    return normalize(np.clip(img, 0, 255).astype(np.uint8)).unsqueeze(0)


# --- Load Model ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HybridRetinaModel(num_classes=5).to(device)
checkpoint = torch.load('weights/best_model.pth', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Load temperature for calibrated probabilities
with open('weights/temperature.json') as f:
    temperature = json.load(f)['temperature']

# Load per-class optimized thresholds (optional)
with open('weights/thresholds.json') as f:
    thresholds = json.load(f)

# --- Run Inference ---
x = preprocess('your_fundus_image.jpg').to(device)
with torch.no_grad():
    logits = model(x)
    probs = torch.softmax(logits / temperature, dim=1).cpu().numpy()[0]

pred_class = CLASS_NAMES[probs.argmax()]
confidence = probs.max()
print(f"Prediction: {pred_class} ({confidence * 100:.1f}%)")
print("\nClass Probabilities:")
for i, name in enumerate(CLASS_NAMES):
    print(f"  {name}: {probs[i] * 100:.1f}%")
```
### MC Dropout Uncertainty Estimation

```python
def mc_dropout_predict(model, x, n_passes=30):
    """Run stochastic forward passes with dropout enabled for uncertainty estimation."""
    # Enable dropout layers only. Calling model.train() would also switch
    # BatchNorm to per-batch statistics, which skews predictions and fails
    # outright for batch size 1.
    model.eval()
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()
    predictions = []
    with torch.no_grad():
        for _ in range(n_passes):
            logits = model(x)
            probs = torch.softmax(logits / temperature, dim=1)
            predictions.append(probs.cpu().numpy())
    model.eval()
    predictions = np.array(predictions)  # (n_passes, batch, classes)
    mean_probs = predictions.mean(axis=0)
    epistemic_uncertainty = predictions.var(axis=0).sum(axis=-1)  # inter-pass variance
    return mean_probs, epistemic_uncertainty


mean_probs, uncertainty = mc_dropout_predict(model, x)
print(f"Prediction: {CLASS_NAMES[mean_probs[0].argmax()]}")
print(f"Epistemic Uncertainty: {uncertainty[0]:.4f}")
```
## Features

### GradCAM Attention Visualization
Generate class-discriminative heatmaps highlighting the retinal regions most relevant to the model's prediction. Useful for clinical interpretability and verifying that the model attends to pathologically relevant structures.
### MC Dropout Uncertainty Estimation
Monte Carlo Dropout with 30 stochastic forward passes provides calibrated epistemic uncertainty estimates. At 90% retention (discarding the most uncertain 10% of predictions), accuracy improves from 80.9% to 86.0%.
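The accuracy-at-retention figure can be reproduced with a short helper. This sketch ranks predictions by maximum softmax probability as a stand-in for the MC-dropout variance score described above, and the toy data is purely illustrative.

```python
import numpy as np


def accuracy_at_retention(probs, labels, retention=0.9):
    """Keep the most confident `retention` fraction of predictions and
    report accuracy on the retained subset (selective prediction)."""
    conf = probs.max(axis=1)
    k = int(retention * len(conf))
    keep = np.argsort(conf)[::-1][:k]  # indices of the k most confident samples
    preds = probs.argmax(axis=1)
    return float((preds[keep] == labels[keep]).mean())


# 9 confident-and-correct predictions plus 1 uncertain-and-wrong one:
probs = np.array([[0.99, 0.01, 0.0, 0.0, 0.0]] * 9 + [[0.15, 0.55, 0.1, 0.1, 0.1]])
labels = np.zeros(10, dtype=int)
print(accuracy_at_retention(probs, labels, retention=1.0))  # 0.9
print(accuracy_at_retention(probs, labels, retention=0.9))  # 1.0
```

Discarding the least certain predictions improves accuracy on the retained set exactly when uncertainty correlates with error, as in the 80.9% to 86.0% jump reported above.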
### Temperature-Calibrated Probabilities
Post-hoc temperature scaling calibrates the model's softmax outputs so that predicted confidence values more closely match empirical accuracy.
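A standard way to fit the temperature (a sketch of the usual procedure, not necessarily the exact script used for this model) is to minimize validation NLL over a single positive scalar:

```python
import torch


def fit_temperature(logits, labels, max_iter=200):
    """Fit a scalar T > 0 minimizing NLL of softmax(logits / T) on held-out
    validation logits; parameterized as exp(log_t) to keep T positive."""
    log_t = torch.zeros(1, requires_grad=True)
    opt = torch.optim.LBFGS([log_t], lr=0.1, max_iter=max_iter)
    nll = torch.nn.CrossEntropyLoss()

    def closure():
        opt.zero_grad()
        loss = nll(logits / log_t.exp(), labels)
        loss.backward()
        return loss

    opt.step(closure)
    return log_t.exp().item()


# Toy sanity check: doubling all logits should roughly double the fitted T,
# since softmax(2*logits / 2T) == softmax(logits / T).
g = torch.Generator().manual_seed(0)
logits = torch.randn(500, 5, generator=g)
labels = (logits + torch.randn(500, 5, generator=g)).argmax(dim=1)
t1 = fit_temperature(logits, labels)
t2 = fit_temperature(2 * logits, labels)
```

At inference the fitted value simply divides the logits before the softmax, exactly as in the inference example above.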
### Per-Class Optimized Decision Thresholds
Class-specific decision thresholds optimized on the validation set improve macro F1 from 0.813 to 0.822 on the held-out test set.
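One common convention for applying such per-class thresholds is sketched below; both the decision rule (largest margin over the threshold, with argmax fallback) and the threshold values are assumptions for illustration, not the contents of thresholds.json.

```python
import numpy as np

CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']


def predict_with_thresholds(probs, thresholds):
    """Pick the class clearing its threshold by the largest margin;
    if no class clears its threshold, fall back to plain argmax."""
    probs = np.asarray(probs)
    margins = probs - np.asarray(thresholds)
    if (margins > 0).any():
        return int(margins.argmax())
    return int(probs.argmax())


# Hypothetical thresholds: a raised bar for Normal trades recall for precision.
thresholds = [0.55, 0.30, 0.30, 0.30, 0.30]
probs = [0.40, 0.35, 0.10, 0.10, 0.05]
print(CLASS_NAMES[predict_with_thresholds(probs, thresholds)])  # Diabetes/DR
```

Here the plain argmax would say Normal, but because Normal's probability falls short of its raised threshold while Diabetes/DR clears its own, the thresholded rule flags the image for disease review.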
### FAISS-Based Similar Case Retrieval
A FAISS index built from the model's 2304-dimensional fusion embeddings enables fast nearest-neighbor retrieval of visually and diagnostically similar cases from the training set, supporting clinical decision-making through case-based reasoning.
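For illustration, the retrieval step is an exact L2 nearest-neighbor search over the embedding gallery. The sketch below does this brute-force in NumPy with a random toy gallery (the same computation `faiss.IndexFlatL2` performs, but without the library dependency or its speed at scale).

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy gallery of 2304-dim fusion embeddings standing in for the training set.
gallery = rng.normal(size=(100, 2304)).astype(np.float32)


def retrieve_similar(query, gallery, k=5):
    """Exact L2 nearest-neighbor search: return indices and distances of the
    k gallery embeddings closest to the query embedding."""
    d2 = ((gallery - query) ** 2).sum(axis=1)
    idx = np.argsort(d2)[:k]
    return idx, np.sqrt(d2[idx])


# A query very close to gallery item 7 should retrieve item 7 first.
q = gallery[7] + 0.01 * rng.normal(size=2304).astype(np.float32)
idx, dists = retrieve_similar(q, gallery, k=3)
print(idx[0])  # 7
```

In the repository, retrieval/build_index.py and retrieval/query_index.py wrap the same idea around a persisted FAISS index over the real training embeddings.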
## Files in Repository

| File / Directory | Description |
|---|---|
| `best_model.pth` | Trained HybridRetinaModel checkpoint (v4) |
| `temperature.json` | Calibrated temperature scaling parameter |
| `thresholds.json` | Per-class optimized decision thresholds |
| `evaluation/` | Evaluation outputs: confusion matrices, ROC curves, per-class metrics, calibration plots |
| `retrieval/` | FAISS index and retrieval scripts for similar case lookup |
| `kfold/` | 5-fold cross-validation results and per-fold metrics |
| `models/hybrid_retina_model.py` | Model architecture definition |
| `training/retinasense_v4.py` | Main training script |
| `training/kfold_cv.py` | 5-fold cross-validation script |
| `training/lesion_attention_training.py` | GradCAM-guided attention fine-tuning |
| `evaluation/eval_dashboard.py` | Comprehensive evaluation dashboard |
| `retrieval/build_index.py` | FAISS index construction |
| `retrieval/query_index.py` | Similar case retrieval interface |
## Limitations and Ethical Considerations

### Technical Limitations
- Normal class has lower F1 (0.69): Early-stage disease presentations overlap visually with normal fundus images, leading to lower precision for the Normal class. Clinically, this means the model errs on the side of flagging images for review rather than missing disease.
- Dataset scope: Trained exclusively on APTOS-2019 and ODIR-5K datasets. Performance may degrade on fundus images from different camera systems, patient demographics, or imaging protocols not represented in the training data.
- Single-label classification: Each image receives one predicted label. Co-morbid conditions (e.g., concurrent DR and glaucoma) are not modeled.
- Cross-validation vs. test gap: The 5-fold CV accuracy (91.1%) is higher than the held-out test accuracy (82.0%), which may reflect distribution differences between augmented training data and real test images.
### Intended Use
This model is intended for research and educational purposes only. It may serve as a screening aid or decision-support tool in research settings, but it is not a medical device and has not been validated for clinical deployment.
### Out-of-Scope Uses
- Clinical diagnosis without ophthalmologist verification
- Deployment as a standalone screening tool in any healthcare setting
- Use on non-fundus images or imaging modalities other than color fundus photography
- Medico-legal decision-making
## Citation

```bibtex
@misc{retinasense_v4_2026,
  title={RetinaSense-ViT v4: Hybrid Vision Transformer for Retinal Disease Classification},
  author={Tanishq Tamarkar},
  year={2026},
  url={https://huggingface.co/tanishq74/retinasense-vit},
  note={ViT-Base/16 + EfficientNet-B3 hybrid fusion model, 5-class retinal disease classification}
}
```
## License

MIT License

## Disclaimer

This is a research prototype for AI-assisted retinal screening. It is NOT a certified medical device and should NOT be used for clinical decision-making without independent verification by a qualified ophthalmologist. All predictions are probabilistic estimates and may be incorrect. The authors assume no liability for any clinical decisions made based on this model's outputs.