RetinaSense-ViT v4: Hybrid Vision Transformer for Retinal Disease Classification

RetinaSense-ViT v4 is a hybrid deep learning model that fuses ViT-Base/16 and EfficientNet-B3 feature representations for automated retinal disease screening from color fundus photographs. It classifies images into 5 disease categories with calibrated confidence estimates, uncertainty quantification, and explainable attention maps.

HuggingFace Repository: tanishq74/retinasense-vit


Model Details

| Property | Value |
|---|---|
| Architecture | HybridRetinaModel (ViT-Base/16 + EfficientNet-B3 fusion) |
| Parameters | 97.8M |
| Input Size | 224 x 224 x 3 (RGB fundus photograph) |
| Output | 5-class probability distribution |
| Framework | PyTorch |
| License | MIT |

Architecture Overview

The HybridRetinaModel extracts complementary features from two pretrained backbones and fuses them through concatenation followed by a classification MLP:

Input Image (224x224x3)
    |
    +---> EfficientNet-B3 ----> 1536-dim features ---+
    |                                                |
    +---> ViT-Base/16 --------> 768-dim features ----+
                                                     |
                                    Concatenation (2304-dim)
                                                     |
                                              MLP Classifier
                                                     |
                                          5-class predictions

  • EfficientNet-B3 branch: Extracts local texture and fine-grained pathological features (1536-dim pooled features)
  • ViT-Base/16 branch: Captures global spatial relationships and long-range dependencies (768-dim CLS token)
  • Fusion layer: Concatenated 2304-dim representation passed through a multi-layer perceptron with batch normalization and dropout
  • Transfer learning: Initialized from v3 pretrained ViT (best_model.pth, 82.59% val acc) and EfficientNet-B3 (efficientnet_b3.pth, 71.1% val acc)

Disease Classes

| Label | Class | Description |
|---|---|---|
| 0 | Normal | No retinal disease detected |
| 1 | Diabetes/DR | Diabetic retinopathy (microaneurysms, hemorrhages, exudates) |
| 2 | Glaucoma | Optic disc cupping and glaucomatous changes |
| 3 | Cataract | Lens opacity affecting fundus image clarity |
| 4 | AMD | Age-related macular degeneration (drusen, atrophy) |

Training

Dataset

The training dataset was constructed by merging two public retinal imaging datasets:

| Source | Original Size | After Cleaning |
|---|---|---|
| APTOS-2019 | 3,662 images | Deduplicated and quality-filtered |
| ODIR-5K | 6,392 images | Deduplicated and quality-filtered |
| Merged (raw) | 8,905 images | -- |
| After deduplication | -- | 5,270 images |
| After balancing | -- | 10,000 images (2,000/class) |

Balancing was achieved through a combination of undersampling over-represented classes and augmentation-based oversampling of under-represented classes to reach exactly 2,000 images per class.
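The balancing step can be sketched as follows. This is a minimal illustration only: for brevity it duplicates under-represented samples, whereas the actual pipeline uses augmentation-based oversampling; the function name `balance_classes` is ours, not from the repository.

```python
import random

random.seed(0)

def balance_classes(samples_by_class, target=2000):
    """Under-sample classes above `target`; over-sample classes below it.
    (Sketch: duplicates samples. Real pipeline: augmented copies.)"""
    balanced = {}
    for label, samples in samples_by_class.items():
        if len(samples) >= target:
            balanced[label] = random.sample(samples, target)          # undersample
        else:
            extra = random.choices(samples, k=target - len(samples))  # oversample
            balanced[label] = samples + extra
    return balanced

# Toy example: one over-represented and one under-represented class
data = {0: list(range(3000)), 3: list(range(500))}
balanced = balance_classes(data, target=2000)
print({k: len(v) for k, v in balanced.items()})  # {0: 2000, 3: 2000}
```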

Data Splits

| Split | Samples | Purpose |
|---|---|---|
| Train | 7,038 | Model training |
| Validation | 1,476 | Hyperparameter tuning and early stopping |
| Test | 1,486 | Final held-out evaluation |

All splits are patient-aware to prevent data leakage -- no patient appears in more than one split.
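A patient-aware split can be implemented along these lines (a sketch assuming per-image patient IDs are available; hashing the ID guarantees every image of a patient lands in the same split -- the repository's actual splitting code may differ):

```python
import hashlib

def patient_split(patient_ids, val_frac=0.15, test_frac=0.15):
    """Assign every image of a patient to the same split by hashing
    the patient ID, so no patient can leak across splits."""
    splits = {"train": [], "val": [], "test": []}
    for i, pid in enumerate(patient_ids):
        h = int(hashlib.md5(str(pid).encode()).hexdigest(), 16) % 100
        if h < test_frac * 100:
            splits["test"].append(i)
        elif h < (test_frac + val_frac) * 100:
            splits["val"].append(i)
        else:
            splits["train"].append(i)
    return splits

# Images 0 and 2 belong to the same (hypothetical) patient "p01",
# so they are guaranteed to land in the same split.
s = patient_split(["p01", "p02", "p01", "p03"])
```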

Preprocessing Pipeline

The preprocessing pipeline must be replicated exactly for correct inference:

Raw Image -> Crop Black Borders -> Resize 224x224 -> CLAHE (L-channel) -> Circular Mask -> Normalize

| Step | Details |
|---|---|
| Border Crop | Remove dark padding (pixels with brightness < 7) |
| Resize | 224 x 224 with cv2.INTER_AREA interpolation |
| CLAHE | Applied to L-channel in LAB color space (clipLimit=2.0, tileGridSize=8x8) |
| Circular Mask | Zero out pixels outside a centered circle (radius = 0.48 x min dimension) |
| Normalize | mean=[0.4298, 0.2784, 0.1559], std=[0.2857, 0.2065, 0.1465] |

Training Configuration

| Parameter | Value |
|---|---|
| Optimizer | AdamW with Layer-wise Learning Rate Decay (LLRD) |
| Scheduler | OneCycleLR with cosine annealing |
| Loss Function | Focal Loss (gamma=2.0) with label smoothing (0.1) |
| Batch Size | 32 |
| Learning Rate | 1e-4 |
| Epochs | 40 |
| Augmentation | MixUp, CutMix, albumentations (flips, rotation, color jitter, elastic transforms) |
| Regularization | Dropout, Stochastic Weight Averaging (SWA) |
| Fine-tuning | Lesion attention training with GradCAM-guided attention loss |
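The loss function can be sketched as focal loss applied to label-smoothed targets. This is one plausible formulation, not necessarily the exact combination used in training (the reference implementation lives in training/retinasense_v4.py):

```python
import torch
import torch.nn.functional as F

def focal_loss_ls(logits, targets, gamma=2.0, smoothing=0.1):
    """Focal loss over label-smoothed targets (a sketch; the exact
    combination used by the repository may differ)."""
    n_classes = logits.size(1)
    log_probs = F.log_softmax(logits, dim=1)
    probs = log_probs.exp()
    # Smoothed one-hot targets: 1 - smoothing on the true class,
    # smoothing spread uniformly over the rest
    with torch.no_grad():
        smooth = torch.full_like(log_probs, smoothing / (n_classes - 1))
        smooth.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)
    # Focal modulation down-weights easy (high-probability) classes
    focal_weight = (1.0 - probs) ** gamma
    loss = -(smooth * focal_weight * log_probs).sum(dim=1)
    return loss.mean()

logits = torch.tensor([[2.0, 0.1, 0.1, 0.1, 0.1]])
loss = focal_loss_ls(logits, torch.tensor([0]))  # confident and correct: small loss
```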

Results

5-Fold Stratified Cross-Validation (10,000 images)

| Metric | Mean | Std |
|---|---|---|
| Accuracy | 91.13% | +/- 0.55% |
| Macro F1 | 0.910 | +/- 0.006 |
| Macro AUC | 0.986 | +/- 0.001 |

Held-Out Test Set (1,486 samples)

| Metric | Value |
|---|---|
| Accuracy | 80.9% (82.0% with optimized thresholds) |
| Macro F1 | 0.813 (0.822 with optimized thresholds) |
| Macro AUC | 0.969 |
| Cohen's Kappa | 0.761 |
| MC Dropout Accuracy @ 90% retention | 86.0% |

Per-Class Test Set Performance

| Class | F1 | AUC | Precision | Recall |
|---|---|---|---|---|
| Normal | 0.69 | 0.926 | 0.573 | 0.857 |
| Diabetes/DR | 0.78 | 0.965 | 0.844 | 0.726 |
| Glaucoma | 0.78 | 0.981 | 0.925 | 0.670 |
| Cataract | 0.95 | 0.997 | 0.940 | 0.966 |
| AMD | 0.87 | 0.977 | 0.917 | 0.827 |

Improvement over v3

| Metric | v3 (Ensemble) | v4 (Hybrid Fusion) | Delta |
|---|---|---|---|
| Accuracy | 74.7% | 82.0% | +7.3 pp |
| Macro F1 | 0.712 | 0.822 | +0.110 |
| Macro AUC | 0.951 | 0.969 | +0.018 |

Usage

Installation

pip install torch torchvision timm opencv-python-headless numpy huggingface_hub

Download Model Weights

from huggingface_hub import hf_hub_download
import shutil, os

os.makedirs("weights", exist_ok=True)
for fname in ["best_model.pth", "temperature.json", "thresholds.json"]:
    path = hf_hub_download(repo_id="tanishq74/retinasense-vit", filename=fname)
    shutil.copy(path, f"weights/{fname}")

Inference Example

import torch
import torch.nn as nn
import timm
import cv2
import json
import numpy as np
from torchvision import transforms

# --- Model Definition ---
class HybridRetinaModel(nn.Module):
    def __init__(self, num_classes=5, drop_rate=0.3):
        super().__init__()
        # EfficientNet-B3 branch (1536-dim)
        self.efficientnet = timm.create_model(
            'efficientnet_b3', pretrained=False, num_classes=0
        )
        # ViT-Base/16 branch (768-dim)
        self.vit = timm.create_model(
            'vit_base_patch16_224', pretrained=False, num_classes=0
        )
        # Fusion MLP: 1536 + 768 = 2304 -> num_classes
        fusion_dim = 2304
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(drop_rate * 0.67),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        eff_features = self.efficientnet(x)   # (B, 1536)
        vit_features = self.vit(x)            # (B, 768)
        fused = torch.cat([eff_features, vit_features], dim=1)  # (B, 2304)
        return self.classifier(fused)

# --- Preprocessing ---
MEAN = [0.4298, 0.2784, 0.1559]
STD = [0.2857, 0.2065, 0.1465]
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']

normalize = transforms.Compose([
    transforms.ToTensor(),  # HWC uint8 ndarray -> CHW float in [0, 1]
    transforms.Normalize(MEAN, STD),
])

def preprocess(img_path):
    img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    # Crop black borders
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    mask = gray > 7
    rows, cols = np.any(mask, axis=1), np.any(mask, axis=0)
    if rows.any() and cols.any():
        r0, r1 = np.where(rows)[0][[0, -1]]
        c0, c1 = np.where(cols)[0][[0, -1]]
        img = img[r0:r1+1, c0:c1+1]
    # Resize
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
    # CLAHE on L-channel
    lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
    lab[:, :, 0] = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(lab[:, :, 0])
    img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    # Circular mask
    h, w = img.shape[:2]
    cmask = np.zeros((h, w), dtype=np.uint8)
    cv2.circle(cmask, (w // 2, h // 2), int(min(h, w) * 0.48), 255, -1)
    img = cv2.bitwise_and(img, img, mask=cmask)
    return normalize(np.clip(img, 0, 255).astype(np.uint8)).unsqueeze(0)

# --- Load Model ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HybridRetinaModel(num_classes=5).to(device)

checkpoint = torch.load('weights/best_model.pth', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Load temperature for calibrated probabilities
with open('weights/temperature.json') as f:
    temperature = json.load(f)['temperature']

# Load per-class optimized thresholds (optional)
with open('weights/thresholds.json') as f:
    thresholds = json.load(f)

# --- Run Inference ---
x = preprocess('your_fundus_image.jpg').to(device)

with torch.no_grad():
    logits = model(x)
    probs = torch.softmax(logits / temperature, dim=1).cpu().numpy()[0]

pred_class = CLASS_NAMES[probs.argmax()]
confidence = probs.max()

print(f"Prediction: {pred_class} ({confidence * 100:.1f}%)")
print("\nClass Probabilities:")
for i, name in enumerate(CLASS_NAMES):
    print(f"  {name}: {probs[i] * 100:.1f}%")

MC Dropout Uncertainty Estimation

def mc_dropout_predict(model, x, n_passes=30):
    """Enable dropout (and only dropout) at test time for uncertainty estimation."""
    model.eval()
    # Putting the whole model in train mode would also switch the fusion MLP's
    # BatchNorm layers to batch statistics (which fails for batch size 1), so
    # activate stochasticity on the Dropout modules alone.
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()
    predictions = []
    with torch.no_grad():
        for _ in range(n_passes):
            logits = model(x)
            probs = torch.softmax(logits / temperature, dim=1)
            predictions.append(probs.cpu().numpy())
    model.eval()

    predictions = np.array(predictions)  # (n_passes, batch, classes)
    mean_probs = predictions.mean(axis=0)
    epistemic_uncertainty = predictions.var(axis=0).sum(axis=-1)  # Inter-pass variance
    return mean_probs, epistemic_uncertainty

mean_probs, uncertainty = mc_dropout_predict(model, x)
print(f"Prediction: {CLASS_NAMES[mean_probs[0].argmax()]}")
print(f"Epistemic Uncertainty: {uncertainty[0]:.4f}")

Features

GradCAM Attention Visualization

Generate class-discriminative heatmaps highlighting the retinal regions most relevant to the model's prediction. Useful for clinical interpretability and verifying that the model attends to pathologically relevant structures.

MC Dropout Uncertainty Estimation

Monte Carlo Dropout with 30 stochastic forward passes provides calibrated epistemic uncertainty estimates. At 90% retention (discarding the most uncertain 10% of predictions), accuracy improves from 80.9% to 86.0%.
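The accuracy-at-retention figure can be reproduced from per-sample uncertainties along these lines (a numpy sketch; `accuracy_at_retention` and the toy arrays are ours, assuming arrays of predicted labels, true labels, and epistemic uncertainties):

```python
import numpy as np

def accuracy_at_retention(preds, labels, uncertainty, retention=0.9):
    """Keep the `retention` fraction of samples with the lowest
    uncertainty and score accuracy on that retained subset."""
    n_keep = int(round(len(preds) * retention))
    keep = np.argsort(uncertainty)[:n_keep]  # most certain first
    return (preds[keep] == labels[keep]).mean()

# Toy example: the two wrong predictions (indices 3 and 8) carry the
# highest uncertainty, so discarding the most uncertain 20% removes both
preds  = np.array([0, 1, 2, 2, 4, 0, 1, 3, 3, 4])
labels = np.array([0, 1, 2, 3, 4, 0, 1, 3, 2, 4])
unc    = np.array([.1, .1, .1, .9, .1, .1, .1, .1, .8, .1])
print(accuracy_at_retention(preds, labels, unc, retention=0.8))  # 1.0
```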

Temperature-Calibrated Probabilities

Post-hoc temperature scaling calibrates the model's softmax outputs so that predicted confidence values more closely match empirical accuracy.
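Fitting the temperature amounts to minimizing validation NLL over a single scalar. A minimal numpy sketch (grid search for clarity; the original fit may use gradient-based optimization, and `fit_temperature` is our name):

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerically stable
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def fit_temperature(logits, labels, grid=np.linspace(0.5, 5.0, 91)):
    """Pick the temperature minimizing negative log-likelihood on
    held-out validation logits."""
    best_t, best_nll = 1.0, np.inf
    for t in grid:
        nll = -log_softmax(logits / t)[np.arange(len(labels)), labels].mean()
        if nll < best_nll:
            best_t, best_nll = t, nll
    return best_t

# Synthetic overconfident model: huge logits, yet wrong 20% of the time,
# so the fitted temperature comes out well above 1 (softening the softmax)
rng = np.random.default_rng(0)
labels = rng.integers(0, 5, size=200)
logits = np.eye(5)[labels] * 10.0
labels[:40] = (labels[:40] + 1) % 5
t = fit_temperature(logits, labels)
```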

Per-Class Optimized Decision Thresholds

Class-specific decision thresholds optimized on the validation set improve macro F1 from 0.813 to 0.822 on the held-out test set.
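Applying the thresholds can be sketched as follows. The decision rule and the threshold values below are illustrative assumptions (the exact rule shipped with thresholds.json may differ, and the JSON's key format is assumed to map class names to floats):

```python
import numpy as np

CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']

def apply_thresholds(probs, thresholds, class_names):
    """Predict the class whose calibrated probability exceeds its
    threshold by the largest margin; fall back to plain argmax if none
    does. (One plausible rule, not necessarily the repository's.)"""
    margins = probs - np.array([thresholds[c] for c in class_names])
    if (margins > 0).any():
        return class_names[int(margins.argmax())]
    return class_names[int(probs.argmax())]

# Hypothetical thresholds: demand stronger evidence for "Normal",
# flag "Diabetes/DR" on weaker evidence
thr = {c: 0.5 for c in CLASS_NAMES}
thr['Normal'] = 0.7
thr['Diabetes/DR'] = 0.25

probs = np.array([0.55, 0.35, 0.04, 0.03, 0.03])
print(apply_thresholds(probs, thr, CLASS_NAMES))  # Diabetes/DR (argmax would say Normal)
```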

FAISS-Based Similar Case Retrieval

A FAISS index built from the model's 2304-dimensional fusion embeddings enables fast nearest-neighbor retrieval of visually and diagnostically similar cases from the training set, supporting clinical decision-making through case-based reasoning.
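The retrieval step is exact nearest-neighbor search over the 2304-dim fusion embeddings. The numpy sketch below shows the same computation a `faiss.IndexFlatL2` performs (FAISS is used in the repository for speed at scale; `nearest_cases` and the random embedding bank are illustrative):

```python
import numpy as np

def nearest_cases(index_embeddings, query, k=3):
    """Exact L2 nearest-neighbor search over stored fusion embeddings --
    the brute-force equivalent of querying a faiss.IndexFlatL2."""
    d2 = ((index_embeddings - query) ** 2).sum(axis=1)
    order = np.argsort(d2)[:k]
    return order, d2[order]

# Stand-in for training-set fusion embeddings
rng = np.random.default_rng(0)
bank = rng.normal(size=(1000, 2304)).astype(np.float32)

# A query very close to stored case 42 retrieves that case first
query = bank[42] + 0.01 * rng.normal(size=2304).astype(np.float32)
idx, dist = nearest_cases(bank, query, k=3)
print(idx[0])  # 42
```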


Files in Repository

| File / Directory | Description |
|---|---|
| best_model.pth | Trained HybridRetinaModel checkpoint (v4) |
| temperature.json | Calibrated temperature scaling parameter |
| thresholds.json | Per-class optimized decision thresholds |
| evaluation/ | Evaluation outputs: confusion matrices, ROC curves, per-class metrics, calibration plots |
| retrieval/ | FAISS index and retrieval scripts for similar case lookup |
| kfold/ | 5-fold cross-validation results and per-fold metrics |
| models/hybrid_retina_model.py | Model architecture definition |
| training/retinasense_v4.py | Main training script |
| training/kfold_cv.py | 5-fold cross-validation script |
| training/lesion_attention_training.py | GradCAM-guided attention fine-tuning |
| evaluation/eval_dashboard.py | Comprehensive evaluation dashboard |
| retrieval/build_index.py | FAISS index construction |
| retrieval/query_index.py | Similar case retrieval interface |

Limitations and Ethical Considerations

Technical Limitations

  • Normal class has lower F1 (0.69): Early-stage disease presentations overlap visually with normal fundus images, leading to lower precision for the Normal class. Clinically, this means the model errs on the side of flagging images for review rather than missing disease.
  • Dataset scope: Trained exclusively on APTOS-2019 and ODIR-5K datasets. Performance may degrade on fundus images from different camera systems, patient demographics, or imaging protocols not represented in the training data.
  • Single-label classification: Each image receives one predicted label. Co-morbid conditions (e.g., concurrent DR and glaucoma) are not modeled.
  • Cross-validation vs. test gap: The 5-fold CV accuracy (91.1%) is higher than the held-out test accuracy (82.0%), which may reflect distribution differences between augmented training data and real test images.

Intended Use

This model is intended for research and educational purposes only. It may serve as a screening aid or decision-support tool in research settings, but it is not a medical device and has not been validated for clinical deployment.

Out-of-Scope Uses

  • Clinical diagnosis without ophthalmologist verification
  • Deployment as a standalone screening tool in any healthcare setting
  • Use on non-fundus images or imaging modalities other than color fundus photography
  • Medico-legal decision-making

Citation

@misc{retinasense_v4_2026,
  title={RetinaSense-ViT v4: Hybrid Vision Transformer for Retinal Disease Classification},
  author={Tanishq Tamarkar},
  year={2026},
  url={https://huggingface.co/tanishq74/retinasense-vit},
  note={ViT-Base/16 + EfficientNet-B3 hybrid fusion model, 5-class retinal disease classification}
}

License

MIT License


Disclaimer

This is a research prototype for AI-assisted retinal screening. It is NOT a certified medical device and should NOT be used for clinical decision-making without independent verification by a qualified ophthalmologist. All predictions are probabilistic estimates and may be incorrect. The authors assume no liability for any clinical decisions made based on this model's outputs.

