πŸ‘οΈ OCT Retinal AI β€” Hybrid CNN-Transformer Weights

Automated classification of retinal OCT scans into four disease categories: CNV, DME, DRUSEN, and NORMAL. Built as part of an MSc dissertation at Newcastle University (2025–26), with a focus on clinical safety and edge deployment.

Live demo: 🤗 HuggingFace Space
GitHub: Animesh-Kr/Human-Eye-Disease-Prediction
Zenodo: 10.5281/zenodo.19224303


## Model Details

| Property | Value |
|---|---|
| Architecture | EfficientNetV2-L + 4-block MHA Transformer + XGBoost |
| Input resolution | 224 × 224 |
| Number of classes | 4 (CNV, DME, DRUSEN, NORMAL) |
| Dataset | Kermany et al., *Cell* 2018 |
| Test accuracy | 95.43% ± 0.27% (5-seed) |
| Macro AUC-ROC | 0.9941 ± 0.0006 |
| Macro F1 | 0.9244 ± 0.0047 |
| ECE (calibrated) | 0.0024 ± 0.0005 |
| Parameters | ~120M |
| Training framework | TensorFlow 2.19 / Keras 3 / XGBoost |
| Edge model | ONNX FP32, 237 MB |

## Evaluation Results (5-Seed Validation)

| Model | Accuracy | Macro AUC-ROC | Macro F1 | ECE |
|---|---|---|---|---|
| Baseline CNN (EffNetV2-L + Dense) | 89.12% | 0.9410 | 0.8642 | 0.0203 |
| Hybrid CNN-Transformer-XGBoost | 95.43% | 0.9941 | 0.9244 | 0.0024 |

### Statistical Significance

- **McNemar's test:** p < 0.0001 on every seed; the improvement over the Dense baseline is not attributable to random variation.
- **Temperature scaling:** ECE drops from ~0.0203 to 0.0024 (roughly an eight-fold reduction), meaning the confidence scores are reliable rather than systematically overconfident.
- **MC Dropout:** 20 stochastic forward passes; σ_max > 0.15 routes the scan to specialist review.
- **OOD detection:** Mahalanobis distance at the 97th percentile correctly rejects non-retinal and corrupted inputs before classification runs.
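The Mahalanobis gate can be sketched using the artifacts shipped in this repository (`ood_train_mean.npy`, `ood_cov_inv.npy`, `ood_threshold.npy`); the array shapes and the step that extracts the 256-d embedding are assumptions, not part of the published pipeline:

```python
import numpy as np

def mahalanobis_ood(z, class_means, cov_inv, threshold):
    """Flag an embedding z as out-of-distribution if its minimum
    Mahalanobis distance to any class mean exceeds the threshold."""
    diffs = class_means - z                          # (num_classes, dim)
    # Squared Mahalanobis distance to each class mean
    d2 = np.einsum("ij,jk,ik->i", diffs, cov_inv, diffs)
    d_min = float(np.sqrt(d2.min()))
    return d_min > threshold, d_min

# With the repo artifacts (assumed shapes: (4, 256), (256, 256), scalar):
#   class_means = np.load("ood_train_mean.npy")
#   cov_inv     = np.load("ood_cov_inv.npy")
#   threshold   = float(np.load("ood_threshold.npy"))
```

A scan whose embedding fails this check is rejected before the classifier runs, which is what keeps non-retinal inputs out of the confusion matrix.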

## Edge Node Benchmark

The 2.07 GB Keras master model was converted to a 237 MB FP32 ONNX graph for edge deployment. This removes the TensorFlow runtime dependency and achieves consistent sub-70 ms latency on CPU.

### Per-Scan Latency (Intel CPU, Batch Size 1, 50 runs)

| Scan | Prediction | Confidence | Latency |
|---|---|---|---|
| DRUSEN | DRUSEN ✅ | 92.72% | 61.96 ms |
| DME | DME ✅ | 81.20% | 62.60 ms |
| CNV | CNV ✅ | 91.46% | 64.31 ms |
| NORMAL | DRUSEN ⚠️ | 66.85% | 62.82 ms |
| Global average | – | – | ~62.9 ms |

The NORMAL scan misclassified as DRUSEN at 66.85% confidence is the canonical example of why the MC Dropout uncertainty flag exists. In this case σ_max exceeded the 0.15 threshold, so the scan would be flagged for specialist review rather than silently misclassified.

### Model Size Reduction

| Format | Size | vs. Master |
|---|---|---|
| Keras `.keras` (master) | 2,070 MB | baseline |
| ONNX FP32 (edge node) | 237 MB | ~88% smaller |

The 88% reduction was achieved through FP32 ONNX export via tf2onnx (opset 17), with mixed-precision FP16 tensors cast to FP32 for runtime compatibility on CPU execution providers.


## Files in This Repository

| File | Size | Description |
|---|---|---|
| `Final_CNN_Transformer.keras` | 2,070 MB | Master model (mixed precision, full TF/Keras stack) |
| `human_eye_fp32.onnx` | 237 MB | Edge node (FP32 ONNX, CPU-deployable, ~88% smaller) |
| `Final_XGBoost_Hybrid.json` | 2.4 MB | XGBoost classification head |
| `ood_train_mean.npy` | 0.6 KB | Per-class training distribution means |
| `ood_cov_inv.npy` | 524 KB | Inverse covariance matrix for Mahalanobis OOD |
| `ood_threshold.npy` | 0.1 KB | 97th-percentile OOD decision threshold |
| `temperature.npy` | 0.1 KB | Learned temperature scalar (T ≈ 1.05) |
| `data_splits/train_indices.npy` | 490 KB | Training split indices (seed-fixed) |
| `data_splits/val_indices.npy` | 61 KB | Validation split indices |
| `data_splits/test_indices.npy` | 61 KB | Test split indices |

## Usage

### Option A – Edge Inference (ONNX, CPU-only, no GPU needed)

```python
import onnxruntime as ort
import numpy as np
import cv2

# Load the 237 MB edge model on the CPU execution provider
sess       = ort.InferenceSession("human_eye_fp32.onnx",
                                  providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
CLASSES    = ["CNV", "DME", "DRUSEN", "NORMAL"]
MEAN       = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # ImageNet stats
STD        = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def predict(img_path):
    img = cv2.imread(img_path)                       # BGR uint8
    img = cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), (224, 224))
    img = (img.astype(np.float32) / 255.0 - MEAN) / STD
    img = np.expand_dims(img, axis=0)                # add batch dimension
    probs = sess.run(None, {input_name: img})[0][0]  # class scores (graph assumed to end in softmax)
    return CLASSES[int(np.argmax(probs))], float(probs.max())

pred, conf = predict("scan.jpg")
print(f"{pred}  ({conf*100:.2f}%)")
```
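If the raw pre-softmax logits are available instead, the shipped temperature scalar (`temperature.npy`, T ≈ 1.05) can calibrate the confidences; a minimal sketch (the logit values below are illustrative only):

```python
import numpy as np

def calibrated_softmax(logits, T=1.05):
    """Temperature-scaled softmax: divide logits by T before normalising.
    T > 1 softens overconfident distributions; T comes from temperature.npy."""
    z = np.asarray(logits, dtype=np.float64) / T
    z -= z.max()              # subtract max for numerical stability
    p = np.exp(z)
    return p / p.sum()

# probs = calibrated_softmax(raw_logits, T=float(np.load("temperature.npy")))
```

This is the temperature-scaling step described under Statistical Significance, applied at inference time.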

### Option B – Download and run locally

```python
from huggingface_hub import hf_hub_download
import os

os.makedirs("models", exist_ok=True)
files = ["Final_CNN_Transformer.keras", "human_eye_fp32.onnx",
         "Final_XGBoost_Hybrid.json", "ood_train_mean.npy",
         "ood_cov_inv.npy", "ood_threshold.npy", "temperature.npy"]

for f in files:
    hf_hub_download(repo_id="animeshakr/oct-retinal-weights",
                    filename=f, local_dir="models/")
```

## Architecture Details

### Phase A – CNN + Transformer backbone

EfficientNetV2-Large (pretrained on ImageNet-21k) extracts spatial features; blocks 1–5 are frozen and block 6 onward is fine-tuned. The 7×7×1280 bottleneck is reshaped into 49 patch tokens, projected to 256 dimensions via a trainable linear layer, and augmented with learnable positional encodings. Four Multi-Head Attention blocks (16 heads, key_dim=16) capture long-range retinal layer dependencies, and GlobalAveragePooling1D produces the 256-dimensional embedding z ∈ ℝ²⁵⁶. Hyperparameters were selected with Optuna TPE (10 trials): η = 1.59×10⁻⁴, dropout = 0.38, Focal Loss γ = 1.36.
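The token pipeline above can be sketched in Keras. This is an illustration of the head only, not the trained weights; details such as the residual/LayerNorm placement are assumptions:

```python
import tensorflow as tf
from tensorflow.keras import layers

class AddPositionalEmbedding(layers.Layer):
    """Learnable positional encodings added to the patch tokens."""
    def build(self, input_shape):
        self.pos = self.add_weight(shape=input_shape[1:],
                                   initializer="random_normal", name="pos")
    def call(self, x):
        return x + self.pos

def phase_a_head(num_blocks=4, num_heads=16, key_dim=16, dim=256):
    inp = layers.Input(shape=(7, 7, 1280))           # EfficientNetV2-L bottleneck
    x = layers.Reshape((49, 1280))(inp)              # 49 patch tokens
    x = layers.Dense(dim)(x)                         # trainable linear projection
    x = AddPositionalEmbedding()(x)
    for _ in range(num_blocks):                      # 4 MHA blocks
        attn = layers.MultiHeadAttention(num_heads=num_heads,
                                         key_dim=key_dim)(x, x)
        x = layers.LayerNormalization()(x + attn)    # residual (placement assumed)
    z = layers.GlobalAveragePooling1D()(x)           # 256-d embedding z
    return tf.keras.Model(inp, z)
```

The resulting model maps a `(batch, 7, 7, 1280)` feature map to a `(batch, 256)` embedding, which Phase B consumes.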

### Phase B – XGBoost head

Frozen embeddings from all training samples are used to fit an XGBoost classifier (Optuna-tuned: n_estimators=300, max_depth=4, η=0.1, subsample=0.8). The gradient-boosting head sharpens minority-class decision boundaries without synthetic oversampling.


## Limitations

- **Single-scanner scope:** Validated on the Kermany dataset collected on one OCT device type. Cross-scanner generalisation (e.g., Topcon vs. Zeiss Cirrus) is unvalidated and is the next planned benchmark.
- **Clinical use:** This is a decision-support tool. It should not replace professional ophthalmological assessment, particularly for edge cases like the NORMAL/DRUSEN boundary.

## Citation

```bibtex
@article{kumar2026oct,
  author  = {Kumar, Animesh A.},
  title   = {A Hybrid {CNN}-Transformer Framework for Retinal {OCT}
             Classification with Integrated Clinical Safety Mechanisms},
  year    = {2026},
  doi     = {10.5281/zenodo.19224303},
  note    = {Zenodo software archive, v1.0.0}
}
```

## References

- Kermany, D. S., et al. (2018). Identifying Medical Diagnoses and Treatable Diseases by Image-Based Deep Learning. *Cell*, 172(5), 1122–1131. https://doi.org/10.1016/j.cell.2018.02.010
- Chen, T., & Guestrin, C. (2016). XGBoost: A Scalable Tree Boosting System. *KDD*. https://doi.org/10.1145/2939672.2939785
- Guo, C., et al. (2017). On Calibration of Modern Neural Networks. *ICML*, 1321–1330.
- Lee, K., et al. (2018). A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks. *NeurIPS*, 31.
- He et al. (2019). On OCT Image Classification via Deep Learning. *IEEE Photonics Journal*, 11(5). https://doi.org/10.1109/JPHOT.2019.2934484
- Li et al. (2021). Applications of Deep Learning in Fundus Images: A Review. *Medical Image Analysis*, 69, 101971.

Developed by Animesh Kumar, Newcastle University
