# 👁️ OCT Retinal AI — Hybrid CNN-Transformer Weights
Automated classification of retinal OCT scans into four disease categories: CNV, DME, DRUSEN, and NORMAL. Built as part of an MSc dissertation at Newcastle University (2025–26), with a focus on clinical safety and edge deployment.
Live demo: 🤗 Hugging Face Space
GitHub: Animesh-Kr/Human-Eye-Disease-Prediction
Zenodo: 10.5281/zenodo.19224303
## Model Details

| Property | Value |
|---|---|
| Architecture | EfficientNetV2L + 4-Block MHA Transformer + XGBoost |
| Input resolution | 224 × 224 |
| Number of classes | 4 (CNV, DME, DRUSEN, NORMAL) |
| Dataset | Kermany et al., Cell 2018 |
| Test accuracy | 95.43% ± 0.27% (5-seed) |
| Macro AUC-ROC | 0.9941 ± 0.0006 |
| Macro F1 | 0.9244 ± 0.0047 |
| ECE (calibrated) | 0.0024 ± 0.0005 |
| Parameters | ~120M |
| Training framework | TensorFlow 2.19 / Keras 3 / XGBoost |
| Edge model | ONNX FP32, 237 MB |
## Evaluation Results (5-Seed Validation)
| Model | Accuracy | Macro AUC-ROC | Macro F1 | ECE |
|---|---|---|---|---|
| Baseline CNN (EffNetV2L + Dense) | 89.12% | 0.9410 | 0.8642 | 0.0203 |
| Hybrid CNN-Transformer-XGBoost | 95.43% | 0.9941 | 0.9244 | 0.0024 |
### Statistical Significance
- McNemar's test: p < 0.0001 on every seed — the improvement over the Dense baseline is not attributable to random variation.
- Temperature scaling: ECE drops from ~0.0203 to 0.0024 (a roughly 8-fold reduction), meaning the confidence scores are reliable rather than systematically overconfident.
- MC Dropout: 20 stochastic forward passes; σ_max > 0.15 routes the scan to specialist review.
- OOD detection: Mahalanobis distance at the 97th percentile correctly rejects non-retinal and corrupted inputs before classification runs.
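The two inference-time safety checks above can be sketched in NumPy. This is a minimal illustration, not the project's exact code: the random class means, identity covariance, and threshold of 300 are synthetic stand-ins for the shipped `ood_train_mean.npy`, `ood_cov_inv.npy`, and `ood_threshold.npy`; the temperature T ≈ 1.05 and the 256-d embedding size follow the values reported in this card.

```python
import numpy as np

def calibrate(logits, T=1.05):
    """Temperature-scaled softmax (T learned on the validation split)."""
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    p = np.exp(z)
    return p / p.sum(axis=-1, keepdims=True)

def mahalanobis_ood(embedding, class_means, cov_inv, threshold):
    """Reject the scan if its minimum class-conditional Mahalanobis
    distance exceeds the 97th-percentile training threshold."""
    diffs = class_means - embedding                     # (4, 256)
    d2 = np.einsum("ij,jk,ik->i", diffs, cov_inv, diffs)
    return bool(d2.min() > threshold)

# Synthetic stand-ins for the shipped .npy artefacts
rng = np.random.default_rng(0)
class_means = rng.normal(size=(4, 256))
cov_inv = np.eye(256)
threshold = 300.0

in_dist = class_means[2] + 0.01 * rng.normal(size=256)  # near a class mean
far_out = class_means[2] + 5.0 * rng.normal(size=256)   # far from all means

print(mahalanobis_ood(in_dist, class_means, cov_inv, threshold))  # → False
print(mahalanobis_ood(far_out, class_means, cov_inv, threshold))  # → True
```

A scan is classified only when the OOD check returns `False`; the calibrated probabilities then feed the MC Dropout gate described above.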
## Edge Node Benchmark
The 2.07 GB Keras master model was converted to a 237 MB FP32 ONNX graph for edge deployment. This removes the TensorFlow runtime dependency and achieves consistent sub-70 ms latency on CPU.
### Per-Scan Latency (Intel CPU, Batch Size 1, 50 runs)

| Scan | Prediction | Confidence | Latency |
|---|---|---|---|
| DRUSEN | DRUSEN ✓ | 92.72% | 61.96 ms |
| DME | DME ✓ | 81.20% | 62.60 ms |
| CNV | CNV ✓ | 91.46% | 64.31 ms |
| NORMAL | DRUSEN ⚠️ | 66.85% | 62.82 ms |
| Global average | — | — | ~62.9 ms |
The NORMAL scan misclassified as DRUSEN at 66.85% confidence is the canonical example of why the MC Dropout uncertainty flag exists. In this case σ_max exceeded the 0.15 threshold, so the scan would be flagged for specialist review rather than silently misclassified.
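The review-routing logic can be sketched as follows. `mc_dropout_gate`, `confident_scan`, and `ambiguous_scan` are hypothetical stand-ins of my own naming: in the real system each stochastic pass is a dropout-enabled forward pass of the model, whereas here two toy samplers imitate a confident scan and an ambiguous one like the NORMAL/DRUSEN case.

```python
import numpy as np

def mc_dropout_gate(stochastic_predict, n_passes=20, sigma_max=0.15):
    """Run n_passes stochastic forward passes (dropout active at inference)
    and flag the scan when the largest per-class std exceeds sigma_max."""
    probs = np.stack([stochastic_predict() for _ in range(n_passes)])  # (20, 4)
    mean_probs = probs.mean(axis=0)
    flag = probs.std(axis=0).max() > sigma_max
    return mean_probs, bool(flag)

rng = np.random.default_rng(42)

def confident_scan():
    # Stable prediction: tiny variation across dropout passes
    return np.array([0.02, 0.02, 0.93, 0.03]) + 0.01 * rng.normal(size=4)

def ambiguous_scan():
    # Unstable prediction: large variation across dropout passes
    return rng.dirichlet([0.5, 0.5, 1.0, 1.0])

_, flag_ok = mc_dropout_gate(confident_scan)        # not flagged
_, flag_review = mc_dropout_gate(ambiguous_scan)    # routed to specialist
print(flag_ok, flag_review)
```

Flagged scans are surfaced to the clinician with the mean probabilities and the per-class spread, rather than a single point prediction.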
## Model Size Reduction

| Format | Size | vs. Master |
|---|---|---|
| Keras `.keras` (master) | 2,070 MB | baseline |
| ONNX FP32 (edge node) | 237 MB | ~88% smaller |
The 88% reduction was achieved through FP32 ONNX export via tf2onnx (opset 17), with mixed-precision FP16 tensors cast to FP32 for runtime compatibility on CPU execution providers.
## Files in This Repository

| File | Size | Description |
|---|---|---|
| `Final_CNN_Transformer.keras` | 2,070 MB | Master model (mixed precision, full TF/Keras stack) |
| `human_eye_fp32.onnx` | 237 MB | Edge node (FP32 ONNX, CPU-deployable, ~88% smaller) |
| `Final_XGBoost_Hybrid.json` | 2.4 MB | XGBoost classification head |
| `ood_train_mean.npy` | 0.6 KB | Per-class training distribution means |
| `ood_cov_inv.npy` | 524 KB | Inverse covariance matrix for Mahalanobis OOD |
| `ood_threshold.npy` | 0.1 KB | 97th percentile OOD decision threshold |
| `temperature.npy` | 0.1 KB | Learned temperature scalar (T ≈ 1.05) |
| `data_splits/train_indices.npy` | 490 KB | Training split indices (seed-fixed) |
| `data_splits/val_indices.npy` | 61 KB | Validation split indices |
| `data_splits/test_indices.npy` | 61 KB | Test split indices |
## Usage

### Option A — Edge Inference (ONNX, CPU-only, no GPU needed)
```python
import onnxruntime as ort
import numpy as np
import cv2

# Load the 237 MB edge model
sess = ort.InferenceSession("human_eye_fp32.onnx",
                            providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
CLASSES = ["CNV", "DME", "DRUSEN", "NORMAL"]

def predict(img_path):
    img = cv2.imread(img_path)
    img = cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), (224, 224))
    # ImageNet normalisation (promotes to float64, so cast back below)
    img = (img.astype(np.float32) / 255.0
           - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
    img = np.expand_dims(img, axis=0).astype(np.float32)
    logits = sess.run(None, {input_name: img})[0]
    # Softmax turns the raw logits into probabilities for the confidence score
    probs = np.exp(logits[0] - logits[0].max())
    probs /= probs.sum()
    return CLASSES[int(np.argmax(probs))], float(probs.max())

pred, conf = predict("scan.jpg")
print(f"{pred} ({conf*100:.2f}%)")
```
### Option B — Download and run locally
```python
from huggingface_hub import hf_hub_download
import os

os.makedirs("models", exist_ok=True)
files = ["Final_CNN_Transformer.keras", "human_eye_fp32.onnx",
         "Final_XGBoost_Hybrid.json", "ood_train_mean.npy",
         "ood_cov_inv.npy", "ood_threshold.npy", "temperature.npy"]
for f in files:
    hf_hub_download(repo_id="animeshakr/oct-retinal-weights",
                    filename=f, local_dir="models/")
```
## Architecture Details

### Phase A — CNN + Transformer backbone
EfficientNetV2-Large (pretrained ImageNet-21k) extracts spatial features. Blocks 1–5 are frozen; Block 6+ is fine-tuned during Phase B. The 7×7×1280 bottleneck is reshaped into 49 patch tokens, projected to 256 dimensions via a trainable linear layer, and augmented with learnable positional encodings. Four Multi-Head Attention blocks (16 heads, key_dim=16) capture long-range retinal layer dependencies. GlobalAveragePooling1D produces the 256-dimensional embedding z ∈ ℝ²⁵⁶. Hyperparameters were selected with Optuna TPE (10 trials): η = 1.59×10⁻⁴, dropout = 0.38, Focal Loss γ = 1.36.
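The token pipeline can be traced shape-by-shape with plain NumPy. This is a sketch only: the projection matrix and positional encodings are random stand-ins for the trained weights, and the four MHA blocks are omitted since they preserve the (49, 256) token shape.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the EfficientNetV2-L bottleneck: (batch, 7, 7, 1280)
features = rng.normal(size=(1, 7, 7, 1280)).astype(np.float32)

# 1. Reshape the 7x7 spatial grid into 49 patch tokens
tokens = features.reshape(1, 49, 1280)

# 2. Linear projection to the 256-d transformer width, plus
#    positional encodings (both trainable in the real model)
W_proj = rng.normal(size=(1280, 256)).astype(np.float32)
pos_enc = rng.normal(size=(49, 256)).astype(np.float32)
tokens = tokens @ W_proj + pos_enc          # (1, 49, 256)

# 3. After the four MHA blocks (omitted here), GlobalAveragePooling1D
#    over the token axis yields the embedding z fed to the XGBoost head
z = tokens.mean(axis=1)                     # (1, 256)
print(tokens.shape, z.shape)                # → (1, 49, 256) (1, 256)
```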
### Phase B — XGBoost head
Frozen embeddings from all training samples are used to fit an XGBoost classifier (Optuna: n_estimators=300, max_depth=4, η = 0.1, subsample=0.8). The gradient boosting head sharpens minority-class decision boundaries without synthetic oversampling.
## Limitations
- Single-scanner scope: Validated on the Kermany dataset collected on one OCT device type. Cross-scanner generalisation (e.g., Topcon vs. Zeiss Cirrus) is unvalidated and is the next planned benchmark.
- Clinical use: This is a decision-support tool. It should not replace professional ophthalmological assessment, particularly for edge cases like the NORMAL/DRUSEN boundary.
## Citation

```bibtex
@article{kumar2026oct,
  author = {Kumar, Animesh A.},
  title  = {A Hybrid {CNN}-Transformer Framework for Retinal {OCT}
            Classification with Integrated Clinical Safety Mechanisms},
  year   = {2026},
  doi    = {10.5281/zenodo.19224303},
  note   = {Zenodo software archive, v1.0.0}
}
```
## References
- Kermany et al. (2018). Identifying Medical Diagnoses and Treatable Diseases by Image-Based Deep Learning. Cell, 172(5), 1122–1131. https://doi.org/10.1016/j.cell.2018.02.010
- Chen & Guestrin (2016). XGBoost: A Scalable Tree Boosting System. KDD. https://doi.org/10.1145/2939672.2939785
- Guo et al. (2017). On Calibration of Modern Neural Networks. ICML, 1321–1330.
- Lee et al. (2018). A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks. NeurIPS, 31.
- He et al. (2019). On OCT Image Classification via Deep Learning. IEEE Photonics Journal, 11(5). https://doi.org/10.1109/JPHOT.2019.2934484
- Li et al. (2021). Applications of Deep Learning in Fundus Images: A Review. Medical Image Analysis, 69, 101971.
Developed by Animesh Kumar — Newcastle University