# AlloGen / code / trainers / trainer.py
# Uploaded by chq1155
# AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
# Commit: ad9572d
"""
Trainer for the Q_theta state-selectivity scorer.
Implements two-phase training:
Phase 1: DockQ regression (learn complex quality from all data)
Phase 2: Selectivity fine-tuning (learn to rank X+ > X- for the same binder)
Integrates with Weights & Biases for experiment tracking.
"""
import os
import time
import logging
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from scipy.stats import spearmanr
from sklearn.metrics import roc_auc_score
import wandb
logger = logging.getLogger(__name__)
class AverageMeter:
    """Track the latest value and the weighted running mean of a scalar.

    Attributes are plain fields:
      val   -- the most recent value passed to ``update``
      sum   -- running sum of ``val * n``
      count -- running sum of the weights ``n``
      avg   -- ``sum / count`` (0.0 while ``count`` is still zero)
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. the batch size).

        A zero total weight (``n=0`` on a fresh meter) leaves ``avg`` at 0.0
        instead of raising ``ZeroDivisionError``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0.0
class AlloDesignerTrainer:
    """
    Two-phase trainer for Q_theta.

    Phase 1 (DockQ regression):
    - Minimizes MSE(Q_theta(X, Y), DockQ_label) on all complex types
    - Learns general complex quality

    Phase 2 (Selectivity fine-tuning):
    - Minimizes selectivity margin loss on paired (pos, neg) data
    - Learns to rank Q(X+, Y) > Q(X-, Y)
    - Combined: L = L_regression + lambda_rank * L_selectivity
    """

    def __init__(self, model, config, device='cuda'):
        """Build optimizer, LR schedule and checkpoint directory.

        Args:
            model: module called as ``model(node_feats, edge_feats, node_mask,
                esm_feats=...)`` and returning one score per sample.
            config: mapping read via ``.get``; recognised keys are
                ``optimizer`` ('adamw' or 'sam'), ``lr``, ``weight_decay``,
                ``warmup_steps``, ``max_steps`` and ``checkpoint_dir``.
            device: device the model and all batches are moved to.
        """
        self.model = model.to(device)
        self.config = config
        self.device = device
        self.use_sam = config.get('optimizer', 'adamw') == 'sam'

        # Optimizer
        adamw_kwargs = {
            'lr': config.get('lr', 1e-4),
            'weight_decay': config.get('weight_decay', 1e-4),
            'betas': (0.9, 0.999),
        }
        if self.use_sam:
            from utils.sam import SAM
            self.optimizer = SAM(
                model.parameters(),
                base_optimizer=AdamW,
                rho=0.05,
                **adamw_kwargs,
            )
            # SAM wraps AdamW; scheduler goes on base_optimizer
            sched_optimizer = self.optimizer.base_optimizer
        else:
            self.optimizer = AdamW(model.parameters(), **adamw_kwargs)
            sched_optimizer = self.optimizer

        # Learning rate scheduler (warmup + cosine)
        n_warmup = config.get('warmup_steps', 100)
        n_total = config.get('max_steps', 5000)
        # Clamp so that max_steps <= warmup_steps cannot produce T_max <= 0,
        # which would divide by zero inside the cosine schedule.
        n_cosine = max(1, n_total - n_warmup)
        warmup_sched = LinearLR(sched_optimizer, start_factor=0.01, end_factor=1.0, total_iters=n_warmup)
        cosine_sched = CosineAnnealingLR(sched_optimizer, T_max=n_cosine, eta_min=1e-6)
        self.scheduler = SequentialLR(sched_optimizer, [warmup_sched, cosine_sched], milestones=[n_warmup])

        self.global_step = 0
        self.best_val_metric = -float('inf')
        self.checkpoint_dir = config.get('checkpoint_dir', 'results/checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #
    def _inputs_to_device(self, sample):
        """Move one sample's model inputs to ``self.device``.

        Returns ``(node_feats, edge_feats, node_mask, esm_feats)`` where
        ``esm_feats`` is ``None`` when the sample has no 'esm_feats' entry.
        """
        esm_feats = sample['esm_feats'].to(self.device) if 'esm_feats' in sample else None
        return (
            sample['node_feats'].to(self.device),
            sample['edge_feats'].to(self.device),
            sample['node_mask'].to(self.device),
            esm_feats,
        )

    def _forward(self, inputs):
        """Run the model on a tuple produced by ``_inputs_to_device``."""
        node_feats, edge_feats, node_mask, esm_feats = inputs
        return self.model(node_feats, edge_feats, node_mask, esm_feats=esm_feats)

    def _backward_and_step(self, loss, recompute_loss):
        """Backpropagate ``loss`` and apply one optimizer + scheduler step.

        Gradients are clipped to norm 1.0 before every optimizer call. With
        SAM, ``recompute_loss`` is invoked after ``first_step`` to obtain the
        loss for the second forward-backward pass; it must evaluate the same
        objective as ``loss`` so both SAM passes agree.
        """
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        if self.use_sam:
            self.optimizer.first_step()
            # Second forward-backward pass
            second_loss = recompute_loss()
            self.optimizer.zero_grad()
            second_loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.second_step()
        else:
            self.optimizer.step()
        self.scheduler.step()
        self.global_step += 1

    @staticmethod
    def _selectivity_losses(pos_scores, neg_scores, margin):
        """Return ``(margin_loss, infonce_loss)`` for paired scores.

        The margin loss is ``relu(margin - (pos - neg))`` averaged over the
        batch. The InfoNCE-style term converts scores to logits (clamped with
        eps=1e-6) and applies a two-way softmax cross-entropy with the
        positive as the target.
        """
        loss_margin = nn.functional.relu(margin - (pos_scores - neg_scores)).mean()
        eps = 1e-6
        pos_logit = torch.log(pos_scores.clamp(eps, 1 - eps) / (1 - pos_scores).clamp(eps))
        neg_logit = torch.log(neg_scores.clamp(eps, 1 - eps) / (1 - neg_scores).clamp(eps))
        log_denom = torch.stack([pos_logit, neg_logit], dim=-1).logsumexp(dim=-1)
        infonce_loss = -(pos_logit - log_denom).mean()
        return loss_margin, infonce_loss

    def _path_monotonicity_loss(self, path_scores, pos_scores, neg_scores, margin):
        """Hinge penalties enforcing neg < path[0] < ... < path[-1] < pos."""
        small_margin = 0.05
        loss_path = torch.tensor(0.0, device=self.device)
        for earlier, later in zip(path_scores, path_scores[1:]):
            loss_path = loss_path + nn.functional.relu(earlier - later + small_margin).mean()
        # Last path frame < positive score
        loss_path = loss_path + nn.functional.relu(path_scores[-1] - pos_scores + margin).mean()
        # First path frame > negative score
        loss_path = loss_path + nn.functional.relu(neg_scores - path_scores[0] + small_margin).mean()
        return loss_path

    def _phase2_terms(self, pos_in, neg_in, pos_label, pos_ce, margin,
                      with_ddg, path_in=(), with_path=False):
        """Run the Phase 2 forward pass and return every loss term.

        Returns a dict with 'pos_scores', 'neg_scores', 'loss_reg',
        'loss_margin', 'loss_infonce', 'loss_ddg' and 'loss_path'. Terms that
        are disabled (or have no data) are zero tensors on ``self.device``.
        """
        pos_scores = self._forward(pos_in)  # [B]
        neg_scores = self._forward(neg_in)  # [B]
        path_scores = [self._forward(frame_in) for frame_in in path_in]

        # Regression loss on positive examples
        loss_reg = nn.functional.mse_loss(pos_scores, pos_label)
        loss_margin, infonce_loss = self._selectivity_losses(pos_scores, neg_scores, margin)

        # ddG auxiliary loss: MSE against contact-energy proxy (physics-informed soft label)
        loss_ddg = torch.tensor(0.0, device=self.device)
        if with_ddg and pos_ce is not None and pos_ce.shape[0] > 0:
            loss_ddg = nn.functional.mse_loss(pos_scores, pos_ce)

        # Path monotonicity loss
        loss_path = torch.tensor(0.0, device=self.device)
        if with_path and path_scores:
            loss_path = self._path_monotonicity_loss(path_scores, pos_scores, neg_scores, margin)

        return {
            'pos_scores': pos_scores,
            'neg_scores': neg_scores,
            'loss_reg': loss_reg,
            'loss_margin': loss_margin,
            'loss_infonce': infonce_loss,
            'loss_ddg': loss_ddg,
            'loss_path': loss_path,
        }

    # ------------------------------------------------------------------ #
    # Phase 1: DockQ regression
    # ------------------------------------------------------------------ #
    def train_step_phase1(self, batch):
        """Single training step for Phase 1 (DockQ regression).

        ``batch`` must provide 'node_feats', 'edge_feats', 'node_mask' and
        'label'; 'esm_feats' is optional. Returns the pre-update loss value,
        the detached scores and the labels (on ``self.device``).
        """
        self.model.train()
        inputs = self._inputs_to_device(batch)
        labels = batch['label'].to(self.device)  # [B]

        def regression_loss():
            scores = self._forward(inputs)  # [B]
            return nn.functional.mse_loss(scores, labels), scores

        self.optimizer.zero_grad()
        loss, scores = regression_loss()
        self._backward_and_step(loss, lambda: regression_loss()[0])
        return {'loss': loss.item(), 'scores': scores.detach(), 'labels': labels}

    def run_phase1(self, train_loader, val_loader, n_epochs: int = 30, run_name: str = 'phase1'):
        """Phase 1 training loop.

        Trains for ``n_epochs``, logs to W&B every 50 global steps and once
        per epoch, and saves 'best_phase1.pt' whenever validation Spearman
        improves. ``run_name`` is accepted for API compatibility and unused.
        """
        logger.info(f"Starting Phase 1 (DockQ regression) for {n_epochs} epochs")
        wandb.define_metric('phase1/step')
        wandb.define_metric('phase1/*', step_metric='phase1/step')
        for epoch in range(n_epochs):
            # Train
            train_meter = AverageMeter()
            all_scores, all_labels = [], []
            for batch in train_loader:
                result = self.train_step_phase1(batch)
                train_meter.update(result['loss'], n=len(result['scores']))
                all_scores.append(result['scores'].cpu().numpy())
                all_labels.append(result['labels'].cpu().numpy())
                if self.global_step % 50 == 0:
                    wandb.log({
                        'phase1/train_loss': result['loss'],
                        'phase1/lr': self.optimizer.param_groups[0]['lr'],
                        'phase1/step': self.global_step,
                    })

            # Compute Spearman corr on training data
            all_scores = np.concatenate(all_scores)
            all_labels = np.concatenate(all_labels)
            train_spearman = spearmanr(all_scores, all_labels).correlation

            # Validate
            val_metrics = self.evaluate_phase1(val_loader)
            logger.info(
                f"Phase1 Epoch {epoch+1}/{n_epochs} | "
                f"Train Loss: {train_meter.avg:.4f} | "
                f"Train Spearman: {train_spearman:.3f} | "
                f"Val Loss: {val_metrics['val_loss']:.4f} | "
                f"Val Spearman: {val_metrics['val_spearman']:.3f} | "
                f"Val AUC: {val_metrics.get('val_auc', 0):.3f}"
            )
            wandb.log({
                'phase1/epoch': epoch + 1,
                'phase1/train_loss_epoch': train_meter.avg,
                'phase1/train_spearman': train_spearman,
                **{f'phase1/{k}': v for k, v in val_metrics.items()},
            })

            # Checkpoint best model
            if val_metrics['val_spearman'] > self.best_val_metric:
                self.best_val_metric = val_metrics['val_spearman']
                self.save_checkpoint('best_phase1.pt', extra={'epoch': epoch, 'phase': 1})
                logger.info(f" -> New best Phase 1 model (val_spearman={self.best_val_metric:.3f})")
        logger.info("Phase 1 training complete.")

    @torch.no_grad()
    def evaluate_phase1(self, loader):
        """Evaluate Phase 1 model on val/test set.

        Returns 'val_loss' (mean of per-batch MSE), 'val_spearman' (NaN is
        mapped to 0.0) and, when both classes of ``label > 0.5`` are present
        and the AUC can be computed, 'val_auc'.
        """
        self.model.eval()
        all_scores, all_labels = [], []
        total_loss = 0.0
        n_batches = 0
        for batch in loader:
            labels = batch['label'].to(self.device)
            scores = self._forward(self._inputs_to_device(batch))
            total_loss += nn.functional.mse_loss(scores, labels).item()
            n_batches += 1
            all_scores.append(scores.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

        all_scores = np.concatenate(all_scores)
        all_labels = np.concatenate(all_labels)
        spearman = spearmanr(all_scores, all_labels).correlation
        if np.isnan(spearman):
            spearman = 0.0
        metrics = {
            'val_loss': total_loss / max(n_batches, 1),
            'val_spearman': float(spearman),
        }

        # AUC for binary quality (label > 0.5 = positive)
        binary_labels = (all_labels > 0.5).astype(int)
        n_positive = binary_labels.sum()
        if 0 < n_positive < len(binary_labels):
            try:
                metrics['val_auc'] = roc_auc_score(binary_labels, all_scores)
            except Exception as exc:
                # AUC is best-effort: keep evaluating, but do not hide why it is missing.
                logger.warning(f"Skipping val_auc: {exc}")
        return metrics

    # ------------------------------------------------------------------ #
    # Phase 2: Selectivity fine-tuning
    # ------------------------------------------------------------------ #
    def train_step_phase2(self, batch, lambda_rank: float = 1.0, margin: float = 0.2,
                          lambda_ddg: float = 0.1):
        """Single training step for Phase 2 (selectivity margin + ddG auxiliary).

        ``batch['pos']`` and ``batch['neg']`` are sample dicts; the positive
        one also carries 'label' and optionally 'contact_energy'. The
        objective is ``loss_reg + lambda_rank * (loss_margin + infonce)
        + lambda_ddg * loss_ddg`` and is used for both SAM passes.
        """
        self.model.train()
        pos = batch['pos']
        neg = batch['neg']
        pos_in = self._inputs_to_device(pos)
        neg_in = self._inputs_to_device(neg)
        pos_label = pos['label'].to(self.device)
        # pos_ce is described upstream as a contact-energy-based ddG proxy in [0, 1]
        pos_ce = pos.get('contact_energy', None)
        if pos_ce is not None:
            pos_ce = pos_ce.to(self.device)

        def objective():
            terms = self._phase2_terms(pos_in, neg_in, pos_label, pos_ce, margin, with_ddg=True)
            total = (terms['loss_reg']
                     + lambda_rank * (terms['loss_margin'] + terms['loss_infonce'])
                     + lambda_ddg * terms['loss_ddg'])
            return total, terms

        self.optimizer.zero_grad()
        loss, terms = objective()
        self._backward_and_step(loss, lambda: objective()[0])

        pos_scores = terms['pos_scores']
        neg_scores = terms['neg_scores']
        selectivity_gap = (pos_scores - neg_scores).mean().item()
        return {
            'loss': loss.item(),
            'loss_reg': terms['loss_reg'].item(),
            'loss_margin': terms['loss_margin'].item(),
            'loss_infonce': terms['loss_infonce'].item(),
            'loss_ddg': terms['loss_ddg'].item(),
            'selectivity_gap': selectivity_gap,
            'pos_scores': pos_scores.detach(),
            'neg_scores': neg_scores.detach(),
        }

    def train_step_phase2_v2(self, batch, lambda_rank: float = 1.0, margin: float = 0.2,
                             lambda_ddg: float = 0.0, lambda_path: float = 0.5):
        """Phase 2 training step with multi-negative + path monotonicity.

        In addition to 'pos' and 'neg', ``batch`` may carry a 'path' list of
        sample dicts scored in order. The ddG term is only evaluated when
        ``lambda_ddg > 0`` and the path term only when ``lambda_path > 0``.
        The update follows the same (optionally SAM) procedure as the other
        training steps.
        """
        self.model.train()
        pos = batch['pos']
        neg = batch['neg']
        pos_in = self._inputs_to_device(pos)
        neg_in = self._inputs_to_device(neg)
        # Path frames get the same inputs (including optional ESM features) as pos/neg.
        path_in = [self._inputs_to_device(frame) for frame in batch.get('path', [])]
        pos_label = pos['label'].to(self.device)
        pos_ce = pos.get('contact_energy', None)
        if pos_ce is not None:
            pos_ce = pos_ce.to(self.device)

        def objective():
            terms = self._phase2_terms(
                pos_in, neg_in, pos_label, pos_ce, margin,
                with_ddg=lambda_ddg > 0,
                path_in=path_in,
                with_path=lambda_path > 0,
            )
            total = (terms['loss_reg']
                     + lambda_rank * (terms['loss_margin'] + terms['loss_infonce'])
                     + lambda_ddg * terms['loss_ddg']
                     + lambda_path * terms['loss_path'])
            return total, terms

        self.optimizer.zero_grad()
        loss, terms = objective()
        self._backward_and_step(loss, lambda: objective()[0])

        pos_scores = terms['pos_scores']
        neg_scores = terms['neg_scores']
        selectivity_gap = (pos_scores - neg_scores).mean().item()
        return {
            'loss': loss.item(),
            'loss_reg': terms['loss_reg'].item(),
            'loss_margin': terms['loss_margin'].item(),
            'loss_infonce': terms['loss_infonce'].item(),
            'loss_ddg': terms['loss_ddg'].item(),
            'loss_path': terms['loss_path'].item(),
            'selectivity_gap': selectivity_gap,
            'pos_scores': pos_scores.detach(),
            'neg_scores': neg_scores.detach(),
        }

    def run_phase2_path(self, train_loader, val_loader, n_epochs: int = 20,
                        lambda_rank: float = 1.0, margin: float = 0.2,
                        lambda_ddg: float = 0.0, lambda_path: float = 0.5):
        """Phase 2 with path-aware training loop.

        Resets the best-metric tracker, trains with ``train_step_phase2_v2``
        and saves 'best_phase2.pt' whenever the validation selectivity gap
        improves.
        """
        logger.info(f"Starting Phase 2 (path-aware) for {n_epochs} epochs "
                    f"[lambda_rank={lambda_rank}, lambda_path={lambda_path}]")
        self.best_val_metric = -float('inf')
        for epoch in range(n_epochs):
            loss_meter = AverageMeter()
            gap_meter = AverageMeter()
            path_meter = AverageMeter()
            for batch in train_loader:
                result = self.train_step_phase2_v2(
                    batch, lambda_rank, margin, lambda_ddg, lambda_path)
                B = len(result['pos_scores'])
                loss_meter.update(result['loss'], B)
                gap_meter.update(result['selectivity_gap'], B)
                path_meter.update(result['loss_path'], B)
                if self.global_step % 50 == 0:
                    wandb.log({
                        'phase2/train_loss': result['loss'],
                        'phase2/loss_margin': result['loss_margin'],
                        'phase2/loss_infonce': result['loss_infonce'],
                        'phase2/loss_path': result['loss_path'],
                        'phase2/selectivity_gap': result['selectivity_gap'],
                        'phase2/lr': self.optimizer.param_groups[0]['lr'],
                        'phase2/step': self.global_step,
                    })

            val_metrics = self.evaluate_phase2(val_loader)
            logger.info(
                f"Phase2-Path Epoch {epoch+1}/{n_epochs} | "
                f"Loss: {loss_meter.avg:.4f} | "
                f"Gap: {gap_meter.avg:.3f} | "
                f"Path: {path_meter.avg:.4f} | "
                f"Val Gap: {val_metrics['val_selectivity_gap']:.3f} | "
                f"Val Acc: {val_metrics['val_ranking_acc']:.3f}"
            )
            wandb.log({
                'phase2/epoch': epoch + 1,
                'phase2/train_loss_epoch': loss_meter.avg,
                'phase2/train_gap_epoch': gap_meter.avg,
                'phase2/train_path_loss_epoch': path_meter.avg,
                **{f'phase2/{k}': v for k, v in val_metrics.items()},
            })
            if val_metrics['val_selectivity_gap'] > self.best_val_metric:
                self.best_val_metric = val_metrics['val_selectivity_gap']
                self.save_checkpoint('best_phase2.pt', extra={'epoch': epoch, 'phase': 2})
                logger.info(f" -> New best Phase 2 model (val_gap={self.best_val_metric:.3f})")
        logger.info("Phase 2 (path-aware) training complete.")

    def run_phase2(self, train_loader, val_loader, n_epochs: int = 20,
                   lambda_rank: float = 1.0, margin: float = 0.2,
                   lambda_ddg: float = 0.1):
        """Phase 2 training loop (selectivity fine-tuning + ddG auxiliary).

        Resets the best-metric tracker, trains with ``train_step_phase2`` and
        saves 'best_phase2.pt' whenever the validation selectivity gap
        improves.
        """
        logger.info(f"Starting Phase 2 (selectivity fine-tuning) for {n_epochs} epochs "
                    f"[lambda_rank={lambda_rank}, lambda_ddg={lambda_ddg}]")
        self.best_val_metric = -float('inf')
        for epoch in range(n_epochs):
            loss_meter = AverageMeter()
            gap_meter = AverageMeter()
            for batch in train_loader:
                result = self.train_step_phase2(batch, lambda_rank, margin, lambda_ddg)
                B = len(result['pos_scores'])
                loss_meter.update(result['loss'], B)
                gap_meter.update(result['selectivity_gap'], B)
                if self.global_step % 50 == 0:
                    wandb.log({
                        'phase2/train_loss': result['loss'],
                        'phase2/loss_margin': result['loss_margin'],
                        'phase2/loss_infonce': result['loss_infonce'],
                        'phase2/loss_ddg': result['loss_ddg'],
                        'phase2/selectivity_gap': result['selectivity_gap'],
                        'phase2/lr': self.optimizer.param_groups[0]['lr'],
                        'phase2/step': self.global_step,
                    })

            # Validate
            val_metrics = self.evaluate_phase2(val_loader)
            logger.info(
                f"Phase2 Epoch {epoch+1}/{n_epochs} | "
                f"Loss: {loss_meter.avg:.4f} | "
                f"Gap: {gap_meter.avg:.3f} | "
                f"Val Gap: {val_metrics['val_selectivity_gap']:.3f} | "
                f"Val Acc: {val_metrics['val_ranking_acc']:.3f}"
            )
            wandb.log({
                'phase2/epoch': epoch + 1,
                'phase2/train_loss_epoch': loss_meter.avg,
                'phase2/train_gap_epoch': gap_meter.avg,
                **{f'phase2/{k}': v for k, v in val_metrics.items()},
            })

            # Checkpoint
            if val_metrics['val_selectivity_gap'] > self.best_val_metric:
                self.best_val_metric = val_metrics['val_selectivity_gap']
                self.save_checkpoint('best_phase2.pt', extra={'epoch': epoch, 'phase': 2})
                logger.info(f" -> New best Phase 2 model (val_gap={self.best_val_metric:.3f})")
        logger.info("Phase 2 training complete.")

    @torch.no_grad()
    def evaluate_phase2(self, loader):
        """Evaluate selectivity on paired (pos, neg) val set.

        Batches without a 'pos' entry are skipped. Returns the mean score gap
        and ranking accuracy (fraction with pos > neg) plus the mean pos/neg
        scores; with no paired batches it returns a gap of 0.0 and an
        accuracy of 0.5 only.
        """
        self.model.eval()
        all_pos_scores, all_neg_scores = [], []
        for batch in loader:
            if 'pos' not in batch:
                continue
            pos_scores = self._forward(self._inputs_to_device(batch['pos']))
            neg_scores = self._forward(self._inputs_to_device(batch['neg']))
            all_pos_scores.append(pos_scores.cpu().numpy())
            all_neg_scores.append(neg_scores.cpu().numpy())

        if not all_pos_scores:
            return {'val_selectivity_gap': 0.0, 'val_ranking_acc': 0.5}

        all_pos = np.concatenate(all_pos_scores)
        all_neg = np.concatenate(all_neg_scores)
        gap = float((all_pos - all_neg).mean())
        acc = float((all_pos > all_neg).mean())
        return {
            'val_selectivity_gap': gap,
            'val_ranking_acc': acc,
            'val_pos_score_mean': float(all_pos.mean()),
            'val_neg_score_mean': float(all_neg.mean()),
        }

    # ------------------------------------------------------------------ #
    # Checkpointing
    # ------------------------------------------------------------------ #
    def save_checkpoint(self, filename: str, extra: dict = None):
        """Write model, optimizer and scheduler state to ``checkpoint_dir``.

        ``extra`` entries are merged into the saved dict (and may override
        the default keys).
        """
        path = os.path.join(self.checkpoint_dir, filename)
        state = {
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            # Saved so a resumed run continues the LR schedule instead of restarting warmup.
            'scheduler_state': self.scheduler.state_dict(),
            'global_step': self.global_step,
            'config': self.config,
        }
        if extra:
            state.update(extra)
        torch.save(state, path)
        logger.debug(f"Saved checkpoint: {path}")

    def load_checkpoint(self, filename: str):
        """Restore training state from ``checkpoint_dir``.

        Returns ``False`` (after a warning) when the file does not exist and
        ``True`` otherwise. Scheduler state is restored only if the
        checkpoint contains it, so older checkpoints still load.
        """
        path = os.path.join(self.checkpoint_dir, filename)
        if not os.path.exists(path):
            logger.warning(f"Checkpoint not found: {path}")
            return False
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        state = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state['model_state'])
        self.optimizer.load_state_dict(state['optimizer_state'])
        if 'scheduler_state' in state:
            self.scheduler.load_state_dict(state['scheduler_state'])
        self.global_step = state.get('global_step', 0)
        logger.info(f"Loaded checkpoint from {path} (step {self.global_step})")
        return True

    # ------------------------------------------------------------------ #
    # Full evaluation (test set)
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def evaluate_test(self, test_loader, phase: int = 2):
        """Full evaluation on test set with all metrics.

        Handles both paired batches (with 'pos'/'neg' sample dicts) and flat
        batches (which must also carry 'type'). Returns
        ``(metrics, all_scores, all_labels, all_types)``. ``phase`` is
        accepted for API compatibility and unused.
        """
        self.model.eval()
        all_scores, all_labels, all_types = [], [], []
        for batch in test_loader:
            if 'pos' in batch:
                # Paired batch
                for key in ['pos', 'neg']:
                    d = batch[key]
                    scores = self._forward(self._inputs_to_device(d))
                    all_scores.extend(scores.cpu().numpy().tolist())
                    all_labels.extend(d['label'].cpu().numpy().tolist())
                    all_types.extend([key] * len(scores))
            else:
                scores = self._forward(self._inputs_to_device(batch))
                all_scores.extend(scores.cpu().numpy().tolist())
                all_labels.extend(batch['label'].cpu().numpy().tolist())
                all_types.extend(batch['type'])

        all_scores = np.array(all_scores)
        all_labels = np.array(all_labels)
        metrics = {}

        # Spearman correlation (all samples). NaN is truthy, so `rho or 0`
        # would let NaN through; test for it explicitly.
        rho = spearmanr(all_scores, all_labels).correlation
        metrics['test_spearman'] = 0.0 if rho is None or np.isnan(rho) else float(rho)

        # AUC (binary: label > 0.5 = positive quality)
        binary = (all_labels > 0.5).astype(int)
        n_positive = binary.sum()
        if 0 < n_positive < len(binary):
            try:
                metrics['test_auc'] = float(roc_auc_score(binary, all_scores))
            except Exception as exc:
                # AUC is best-effort: keep evaluating, but do not hide why it is missing.
                logger.warning(f"Skipping test_auc: {exc}")

        # Selectivity gap (pos vs neg_apo pairs)
        pos_mask = np.array([t == 'pos' or t == 'positive' for t in all_types])
        neg_mask = np.array([t == 'neg' or t == 'negative_apo' for t in all_types])
        if pos_mask.sum() > 0 and neg_mask.sum() > 0:
            metrics['test_selectivity_gap'] = float(all_scores[pos_mask].mean() - all_scores[neg_mask].mean())

        logger.info(f"Test evaluation: {metrics}")
        wandb.log({f'test/{k}': v for k, v in metrics.items()})
        return metrics, all_scores, all_labels, all_types