Exoplanet Transit Detector 🔭🪐

A multi-branch 1D CNN that detects exoplanet transits in stellar light curves, modeled on the AstroNet and ExoMiner++ (NASA Ames) architectures.

Model Architecture

Multi-Branch 1D CNN with 5 input branches:

  • Global flux branch: Phase-folded full light curve (201 bins)
  • Local flux branch: Zoomed transit view (81 bins)
  • Odd transit branch: Odd-numbered transits (201 bins)
  • Even transit branch: Even-numbered transits (201 bins)
  • Scalar features: Period, duration, depth, stellar parameters (9 features)

Each flux branch uses 2 convolutional blocks with 3 conv layers each (8→16 filters), batch normalization, and max pooling. Branches are fused and fed through a 4-layer fully-connected classifier head.
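To make the branch-and-fuse layout concrete, here is a minimal PyTorch sketch of one way such a model could be wired. It is not the released architecture: the kernel size, the adaptive pooling step, the hidden widths of the head, fusion by concatenation, and the class names `FluxBranch` and `MultiBranchSketch` are all assumptions, so the parameter count will not match the figure quoted in this card and the released weights will not load into it.

```python
import torch
import torch.nn as nn


def conv_block(in_ch: int, out_ch: int, n_layers: int = 3) -> nn.Sequential:
    """n_layers x (Conv1d -> BatchNorm1d -> ReLU), followed by max pooling."""
    layers = []
    for i in range(n_layers):
        layers += [
            # kernel_size=5 is an assumed value, not taken from the model card
            nn.Conv1d(in_ch if i == 0 else out_ch, out_ch, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(),
        ]
    layers.append(nn.MaxPool1d(kernel_size=2))
    return nn.Sequential(*layers)


class FluxBranch(nn.Module):
    """Two conv blocks (1 -> 8 -> 16 filters) applied to one binned flux view."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            conv_block(1, 8),
            conv_block(8, 16),
            nn.AdaptiveMaxPool1d(4),  # assumed: fixed-length output for 81 or 201 bins
            nn.Flatten(),
        )
        self.out_features = 16 * 4

    def forward(self, x):
        return self.net(x.unsqueeze(1))  # (batch, bins) -> (batch, 1, bins)


class MultiBranchSketch(nn.Module):
    """Four flux branches plus scalar features, fused by concatenation (assumed)."""

    def __init__(self, n_scalars: int = 9, num_classes: int = 3):
        super().__init__()
        self.branches = nn.ModuleList([FluxBranch() for _ in range(4)])
        fused = 4 * self.branches[0].out_features + n_scalars
        # 4 linear layers; the hidden widths are assumed values
        self.head = nn.Sequential(
            nn.Linear(fused, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
            nn.Linear(32, num_classes),
        )

    def forward(self, flux_global, flux_local, flux_odd, flux_even, scalars):
        fluxes = (flux_global, flux_local, flux_odd, flux_even)
        feats = [branch(x) for branch, x in zip(self.branches, fluxes)]
        return self.head(torch.cat(feats + [scalars], dim=-1))  # raw logits


model = MultiBranchSketch().eval()
with torch.no_grad():
    logits = model(
        torch.randn(2, 201), torch.randn(2, 81),
        torch.randn(2, 201), torch.randn(2, 201), torch.randn(2, 9),
    )
```

Unlike the released model, whose output exposes a `.logits` attribute, this sketch returns the logits tensor directly.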

Total parameters: 244,181

Performance

| Metric | Test Set |
|---|---|
| Accuracy | 89.10% |
| F1 (weighted) | 89.03% |
| Precision (weighted) | 89.03% |
| Recall (weighted) | 89.10% |
| Loss | 0.2804 |
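For readers unfamiliar with support-weighted averaging, the sketch below computes weighted precision, recall, and F1 from a confusion matrix. With per-class support as the weights, weighted recall is algebraically equal to accuracy, which is why those two rows report the same value. The confusion matrix here is a toy example invented for illustration, not this model's results, and `weighted_metrics` is a hypothetical helper name.

```python
import numpy as np


def weighted_metrics(cm: np.ndarray) -> dict:
    """cm[i, j] = number of samples with true class i predicted as class j."""
    cm = cm.astype(float)
    tp = np.diag(cm)
    support = cm.sum(axis=1)      # true samples per class
    predicted = cm.sum(axis=0)    # predictions per class
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    weights = support / support.sum()  # each class weighted by its share of samples
    return {
        "accuracy": float(tp.sum() / cm.sum()),
        "precision": float(weights @ precision),
        "recall": float(weights @ recall),
        "f1": float(weights @ f1),
    }


# Toy 3-class confusion matrix, invented for illustration only.
toy_cm = np.array([[50, 5, 5], [4, 30, 6], [2, 3, 45]])
metrics = weighted_metrics(toy_cm)
```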

Training Details

  • Dataset: bingbangboom/exoplanet-transit-detection
    • Multi-mission: Kepler + TESS + K2
    • 18,853 train / 2,357 val / 2,357 test samples
    • 3-class: PLANET, FALSE_POSITIVE, NO_SIGNAL
  • Optimizer: AdamW (lr=5e-4, cosine schedule, 5% warmup)
  • Loss: Weighted cross-entropy (inverse frequency balancing)
  • Epochs: 30 (best model at epoch 15)
  • Batch size: 128
  • Architecture references: ExoMiner++ (arXiv:2502.09790) and AstroNet-Triage-v2 (arXiv:2301.01371)
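The loss weighting and learning-rate schedule listed above can be sketched in a few lines of plain Python. Both functions are assumptions about details the card does not state: `inverse_frequency_weights` uses the common N / (K × count) convention, and `lr_at_step` assumes linear warmup followed by cosine decay to zero. The class counts are hypothetical. The step count assumes the final partial batch of each epoch is kept, giving ceil(18,853 / 128) = 148 steps per epoch and 4,440 steps over 30 epochs.

```python
import math


def inverse_frequency_weights(class_counts):
    """weight_c = N / (K * count_c), so rarer classes get larger weights."""
    total = sum(class_counts)
    n_classes = len(class_counts)
    return [total / (n_classes * count) for count in class_counts]


def lr_at_step(step, total_steps, base_lr=5e-4, warmup_frac=0.05):
    """Linear warmup over the first warmup_frac of steps, then cosine decay to 0."""
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Hypothetical class counts; the real per-class counts are not given in this card.
weights = inverse_frequency_weights([100, 300, 600])

steps_per_epoch = math.ceil(18_853 / 128)  # 148, assuming the last partial batch is kept
total_steps = 30 * steps_per_epoch         # 4440
schedule = [lr_at_step(s, total_steps) for s in range(total_steps)]
```

In PyTorch, weights like these would typically be passed as the `weight` tensor of `torch.nn.CrossEntropyLoss`.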

Usage

import torch
import numpy as np
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# 1. Download and load model
model_path = hf_hub_download("sarojpatil16/exoplanet-transit-detector", "model.safetensors")

# 2. Recreate architecture (copy from this model card or train.py)
# ... (see model architecture code below) ...
# model = AstroNetCNN(n_scalars=9, num_classes=3)
# model.load_state_dict(load_file(model_path))
# model.eval()

# 3. Prepare inputs from light curve data
# flux_global: (1, 201) - phase-folded full light curve, median-subtracted & MAD-normalized
# flux_local:  (1, 81)  - zoomed transit view
# flux_odd:    (1, 201) - odd-numbered transits
# flux_even:   (1, 201) - even-numbered transits
# scalars:     (1, 9)   - [period_days, duration_hrs, depth_ppm, teff, logg, radius, mass, metallicity, kepmag]
#                          (period, duration, depth are log1p-transformed)

# 4. Predict
# with torch.no_grad():
#     output = model(flux_global, flux_local, flux_odd, flux_even, scalars)
#     probabilities = torch.softmax(output.logits, dim=-1)
#     pred_class = torch.argmax(output.logits, dim=-1)
#     # 0=PLANET, 1=FALSE_POSITIVE, 2=NO_SIGNAL
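Step 3 above names the input shapes and normalization but not the code that produces them. The NumPy sketch below shows one plausible way to build the global and local views and the scalar vector from a raw light curve. The helper names `fold_and_bin` and `robust_normalize`, median binning, the empty-bin fill, the local window of two transit durations on each side, and the synthetic light curve and stellar values are all assumptions. The card only states that views are median-subtracted and MAD-normalized and that period, duration, and depth are log1p-transformed.

```python
import numpy as np


def fold_and_bin(time, flux, period, t0, n_bins, half_width=0.5):
    """Phase-fold on `period`, centre the transit at phase 0, median-bin the flux.

    `half_width` is the phase window kept on each side of the transit
    (0.5 keeps the full orbit).
    """
    phase = ((time - t0) / period + 0.5) % 1.0 - 0.5
    keep = np.abs(phase) <= half_width
    edges = np.linspace(-half_width, half_width, n_bins + 1)
    idx = np.clip(np.digitize(phase[keep], edges) - 1, 0, n_bins - 1)
    kept_flux = flux[keep]
    binned = np.full(n_bins, np.nan)
    for b in range(n_bins):
        in_bin = kept_flux[idx == b]
        if in_bin.size:
            binned[b] = np.median(in_bin)
    # Fill empty bins with the median of the populated ones (an assumption).
    binned[np.isnan(binned)] = np.nanmedian(binned)
    return binned


def robust_normalize(x, eps=1e-8):
    """Subtract the median and divide by the median absolute deviation (MAD)."""
    med = np.median(x)
    mad = np.median(np.abs(x - med))
    return (x - med) / (mad + eps)


# Synthetic light curve with a 1000 ppm box-shaped transit, for illustration only.
rng = np.random.default_rng(0)
time = np.arange(0.0, 90.0, 0.0204)    # days
period, t0, duration = 3.5, 1.2, 0.15  # days
flux = 1.0 + 1e-4 * rng.standard_normal(time.size)
phase = ((time - t0) / period + 0.5) % 1.0 - 0.5
flux[np.abs(phase) * period < duration / 2] -= 1e-3

flux_global = robust_normalize(fold_and_bin(time, flux, period, t0, 201))[None, :]
flux_local = robust_normalize(
    fold_and_bin(time, flux, period, t0, 81, half_width=2 * duration / period)
)[None, :]

# [period_days, duration_hrs, depth_ppm, teff, logg, radius, mass, metallicity, kepmag]
# The stellar values are placeholders; only the first three get log1p.
scalars = np.array([[period, duration * 24.0, 1000.0, 5778.0, 4.44, 1.0, 1.0, 0.0, 12.0]])
scalars[:, :3] = np.log1p(scalars[:, :3])
```

The odd and even views would follow the same pattern after selecting only the odd- or even-numbered transit epochs. Before calling the model, each array would need converting to a float32 tensor, for example with `torch.from_numpy(x.astype(np.float32))`. Any further scaling of the stellar parameters is not documented in this card.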

Label Mapping

| Class ID | Label | Description |
|---|---|---|
| 0 | PLANET | Confirmed or candidate exoplanet transit |
| 1 | FALSE_POSITIVE | Signal is not a planet (eclipsing binary, stellar variability, etc.) |
| 2 | NO_SIGNAL | No significant transit signal detected |
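As a small illustration of applying this mapping, the sketch below turns one sample's raw logits into a label and a softmax probability using only the standard library. `ID2LABEL` and `decode` are hypothetical names, and the example logits are made up.

```python
import math

ID2LABEL = {0: "PLANET", 1: "FALSE_POSITIVE", 2: "NO_SIGNAL"}


def decode(logits):
    """Return (label, probability) for one sample's raw logits."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return ID2LABEL[best], probs[best]


label, prob = decode([2.0, 0.5, -1.0])  # made-up logits; label is "PLANET"
```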
