| """ |
| Trainer for the Q_theta state-selectivity scorer. |
| |
| Implements two-phase training: |
| Phase 1: DockQ regression (learn complex quality from all data) |
| Phase 2: Selectivity fine-tuning (learn to rank X+ > X- for the same binder) |
| |
| Integrates with Weights & Biases for experiment tracking. |
| """ |
|
|
| import os |
| import time |
| import logging |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR |
| from scipy.stats import spearmanr |
| from sklearn.metrics import roc_auc_score |
|
|
| import wandb |
|
|
# Module-level logger used by the trainer for progress, checkpoint and metric messages.
logger = logging.getLogger(__name__)
|
|
|
|
class AverageMeter:
    """Running, count-weighted average of a scalar metric.

    Attributes:
        val: the most recent value passed to ``update``.
        avg: ``sum / count`` (0.0 until the first update).
        sum: weighted sum of all values seen (each ``val * n``).
        count: total weight ``n`` accumulated so far.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every accumulated statistic."""
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean over ``n`` samples)."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
|
|
|
|
class AlloDesignerTrainer:
    """
    Two-phase trainer for Q_theta.

    Phase 1 (DockQ regression):
      - Minimizes MSE(Q_theta(X, Y), DockQ_label) on all complex types
      - Learns general complex quality

    Phase 2 (Selectivity fine-tuning):
      - Minimizes selectivity losses on paired (pos, neg) data
      - Learns to rank Q(X+, Y) > Q(X-, Y)
      - Combined: L = L_regression + lambda_rank * (L_margin + L_infonce)
                      + lambda_ddg * L_ddg
        (the path-aware variant adds lambda_path * L_path)

    Batches are dicts of tensors. A "graph" dict holds 'node_feats',
    'edge_feats', 'node_mask', optionally 'esm_feats', and (where a loss or
    metric needs it) 'label'. The model is called as
    ``model(node_feats, edge_feats, node_mask, esm_feats=...)`` and is
    assumed to return one score per sample, presumably in (0, 1): the
    Phase 2 InfoNCE term treats scores as probabilities -- TODO confirm
    against the model definition.

    Config keys read (defaults in parentheses): 'optimizer' ('adamw'; 'sam'
    enables SAM), 'lr' (1e-4), 'weight_decay' (1e-4), 'sam_rho' (0.05),
    'warmup_steps' (100), 'max_steps' (5000),
    'checkpoint_dir' ('results/checkpoints').
    """

    def __init__(self, model, config, device='cuda'):
        self.model = model.to(device)
        self.config = config
        self.device = device
        self.use_sam = config.get('optimizer', 'adamw') == 'sam'

        opt_kwargs = dict(
            lr=config.get('lr', 1e-4),
            weight_decay=config.get('weight_decay', 1e-4),
            betas=(0.9, 0.999),
        )
        if self.use_sam:
            # Imported lazily so the SAM dependency is only needed when requested.
            from utils.sam import SAM
            self.optimizer = SAM(
                self.model.parameters(),
                base_optimizer=AdamW,
                rho=config.get('sam_rho', 0.05),
                **opt_kwargs,
            )
            # The LR schedule drives the wrapped base optimizer's param groups.
            sched_optimizer = self.optimizer.base_optimizer
        else:
            self.optimizer = AdamW(self.model.parameters(), **opt_kwargs)
            sched_optimizer = self.optimizer

        # Linear warmup followed by cosine decay to eta_min. T_max is clamped to
        # >= 1 because CosineAnnealingLR divides by it: warmup_steps >= max_steps
        # would otherwise raise ZeroDivisionError at the warmup milestone.
        # NOTE: the scheduler is stepped once per optimizer step across *all*
        # phases; beyond max_steps the cosine curve rises again.
        n_warmup = config.get('warmup_steps', 100)
        n_total = config.get('max_steps', 5000)
        warmup_sched = LinearLR(sched_optimizer, start_factor=0.01, end_factor=1.0, total_iters=n_warmup)
        cosine_sched = CosineAnnealingLR(sched_optimizer, T_max=max(1, n_total - n_warmup), eta_min=1e-6)
        self.scheduler = SequentialLR(sched_optimizer, [warmup_sched, cosine_sched], milestones=[n_warmup])

        self.global_step = 0
        self.best_val_metric = -float('inf')
        self.checkpoint_dir = config.get('checkpoint_dir', 'results/checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _inputs(self, graph):
        """Move one graph dict's model inputs to ``self.device``.

        Returns ``(node_feats, edge_feats, node_mask, esm_feats)``; ``esm_feats``
        is None when the dict has no 'esm_feats' entry.
        """
        esm_feats = graph['esm_feats'].to(self.device) if 'esm_feats' in graph else None
        return (
            graph['node_feats'].to(self.device),
            graph['edge_feats'].to(self.device),
            graph['node_mask'].to(self.device),
            esm_feats,
        )

    def _forward(self, inputs):
        """Score an ``(node_feats, edge_feats, node_mask, esm_feats)`` tuple."""
        node_feats, edge_feats, node_mask, esm_feats = inputs
        return self.model(node_feats, edge_feats, node_mask, esm_feats=esm_feats)

    def _optimize(self, compute_loss):
        """Take one optimizer step (plus one scheduler step) on ``compute_loss``.

        ``compute_loss()`` runs the forward pass and returns ``(loss, extras)``.
        Without SAM it is called once. With SAM it is called a second time
        after ``optimizer.first_step()`` and that second loss drives
        ``optimizer.second_step()`` (utils.sam is assumed to follow the usual
        first_step/second_step protocol). The returned ``(loss, extras)``
        always come from the first call.
        """
        self.optimizer.zero_grad()
        loss, extras = compute_loss()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        if self.use_sam:
            self.optimizer.first_step()
            loss_perturbed, _ = compute_loss()
            self.optimizer.zero_grad()
            loss_perturbed.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.second_step()
        else:
            self.optimizer.step()

        self.scheduler.step()
        self.global_step += 1
        return loss, extras

    @staticmethod
    def _selectivity_losses(pos_scores, neg_scores, margin, eps=1e-6):
        """Return ``(hinge margin loss, two-way InfoNCE loss)`` for paired scores.

        The hinge penalizes ``pos - neg < margin``. For InfoNCE the scores are
        mapped to logits ``log(p / (1 - p))`` (eps-clamped for numerical
        safety) and the positive is classified against {pos, neg}.
        """
        loss_margin = nn.functional.relu(margin - (pos_scores - neg_scores)).mean()
        pos_logit = torch.log(pos_scores.clamp(eps, 1 - eps) / (1 - pos_scores).clamp(eps))
        neg_logit = torch.log(neg_scores.clamp(eps, 1 - eps) / (1 - neg_scores).clamp(eps))
        log_denom = torch.stack([pos_logit, neg_logit], dim=-1).logsumexp(dim=-1)
        infonce_loss = -(pos_logit - log_denom).mean()
        return loss_margin, infonce_loss

    @staticmethod
    def _path_monotonicity_loss(pos_scores, neg_scores, path_scores, margin, small_margin=0.05):
        """Hinge penalties encouraging ``neg < path[0] < ... < path[-1] < pos``.

        Adjacent path frames, and the (neg, path[0]) pair, must be separated by
        ``small_margin``; the last frame must sit ``margin`` below pos.
        ``path_scores`` must be non-empty.
        """
        loss = 0.0
        for earlier, later in zip(path_scores, path_scores[1:]):
            loss = loss + nn.functional.relu(earlier - later + small_margin).mean()
        loss = loss + nn.functional.relu(path_scores[-1] - pos_scores + margin).mean()
        loss = loss + nn.functional.relu(neg_scores - path_scores[0] + small_margin).mean()
        return loss

    def _phase2_inputs(self, batch):
        """Unpack a paired batch into device tensors.

        Returns ``(pos_inputs, neg_inputs, pos_label, pos_contact_energy)``;
        the last item is None when 'contact_energy' is absent from ``pos``.
        """
        pos, neg = batch['pos'], batch['neg']
        pos_ce = pos.get('contact_energy', None)
        if pos_ce is not None:
            pos_ce = pos_ce.to(self.device)
        return self._inputs(pos), self._inputs(neg), pos['label'].to(self.device), pos_ce

    @staticmethod
    def _phase2_result(loss, parts):
        """Convert a Phase 2 loss and its named parts into the step's result dict."""
        result = {'loss': loss.item()}
        for key in ('loss_reg', 'loss_margin', 'loss_infonce', 'loss_ddg', 'loss_path'):
            if key in parts:
                result[key] = parts[key].item()
        pos_scores, neg_scores = parts['pos_scores'], parts['neg_scores']
        result['selectivity_gap'] = (pos_scores - neg_scores).mean().item()
        result['pos_scores'] = pos_scores.detach()
        result['neg_scores'] = neg_scores.detach()
        return result

    @staticmethod
    def _safe_auc(labels, scores, threshold=0.5):
        """ROC-AUC of ``scores`` vs ``labels > threshold``; None when undefined.

        AUC is undefined when only one class is present; scikit-learn raises
        ValueError for that and for invalid (e.g. NaN) scores.
        """
        binary = (labels > threshold).astype(int)
        if binary.sum() == 0 or binary.sum() == len(binary):
            return None
        try:
            return float(roc_auc_score(binary, scores))
        except ValueError as exc:
            logger.warning(f"ROC-AUC could not be computed: {exc}")
            return None

    # ------------------------------------------------------------------
    # Phase 1: DockQ regression
    # ------------------------------------------------------------------

    def train_step_phase1(self, batch):
        """Single training step for Phase 1 (DockQ regression).

        Returns the first-pass loss, detached scores and the device labels.
        """
        self.model.train()
        inputs = self._inputs(batch)
        labels = batch['label'].to(self.device)

        def compute_loss():
            scores = self._forward(inputs)
            return nn.functional.mse_loss(scores, labels), scores

        loss, scores = self._optimize(compute_loss)
        return {'loss': loss.item(), 'scores': scores.detach(), 'labels': labels}

    def run_phase1(self, train_loader, val_loader, n_epochs: int = 30, run_name: str = 'phase1'):
        """Phase 1 training loop.

        Evaluates on ``val_loader`` after each epoch and saves 'best_phase1.pt'
        whenever val Spearman beats ``self.best_val_metric``. ``run_name`` is
        currently unused.
        """
        logger.info(f"Starting Phase 1 (DockQ regression) for {n_epochs} epochs")
        wandb.define_metric('phase1/step')
        wandb.define_metric('phase1/*', step_metric='phase1/step')

        for epoch in range(n_epochs):
            train_meter = AverageMeter()
            all_scores, all_labels = [], []

            for batch in train_loader:
                result = self.train_step_phase1(batch)
                train_meter.update(result['loss'], n=len(result['scores']))
                all_scores.append(result['scores'].cpu().numpy())
                all_labels.append(result['labels'].cpu().numpy())

                if self.global_step % 50 == 0:
                    wandb.log({
                        'phase1/train_loss': result['loss'],
                        'phase1/lr': self.optimizer.param_groups[0]['lr'],
                        'phase1/step': self.global_step,
                    })

            # An empty loader leaves nothing to correlate (np.concatenate would raise).
            if all_scores:
                train_spearman = spearmanr(np.concatenate(all_scores), np.concatenate(all_labels)).correlation
            else:
                train_spearman = float('nan')

            val_metrics = self.evaluate_phase1(val_loader)

            logger.info(
                f"Phase1 Epoch {epoch+1}/{n_epochs} | "
                f"Train Loss: {train_meter.avg:.4f} | "
                f"Train Spearman: {train_spearman:.3f} | "
                f"Val Loss: {val_metrics['val_loss']:.4f} | "
                f"Val Spearman: {val_metrics['val_spearman']:.3f} | "
                f"Val AUC: {val_metrics.get('val_auc', 0):.3f}"
            )

            wandb.log({
                'phase1/epoch': epoch + 1,
                'phase1/train_loss_epoch': train_meter.avg,
                'phase1/train_spearman': train_spearman,
                **{f'phase1/{k}': v for k, v in val_metrics.items()},
            })

            if val_metrics['val_spearman'] > self.best_val_metric:
                self.best_val_metric = val_metrics['val_spearman']
                self.save_checkpoint('best_phase1.pt', extra={'epoch': epoch, 'phase': 1})
                logger.info(f"  -> New best Phase 1 model (val_spearman={self.best_val_metric:.3f})")

        logger.info("Phase 1 training complete.")

    @torch.no_grad()
    def evaluate_phase1(self, loader):
        """Evaluate Phase 1 model on val/test set.

        Returns 'val_loss' (per-sample MSE over the whole loader),
        'val_spearman' (NaN mapped to 0.0) and, when labels thresholded at 0.5
        contain both classes, 'val_auc'. An empty loader yields
        ``{'val_loss': nan, 'val_spearman': 0.0}``.
        """
        self.model.eval()
        all_scores, all_labels = [], []
        total_sq_err = 0.0
        n_samples = 0

        for batch in loader:
            labels = batch['label'].to(self.device)
            scores = self._forward(self._inputs(batch))
            # Sum-reduce so every sample is weighted equally; a mean of
            # per-batch means over-weights a short final batch.
            total_sq_err += nn.functional.mse_loss(scores, labels, reduction='sum').item()
            n_samples += labels.numel()
            all_scores.append(scores.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

        if not all_scores:
            return {'val_loss': float('nan'), 'val_spearman': 0.0}

        all_scores = np.concatenate(all_scores)
        all_labels = np.concatenate(all_labels)

        spearman = spearmanr(all_scores, all_labels).correlation
        if np.isnan(spearman):
            spearman = 0.0

        metrics = {
            'val_loss': total_sq_err / max(n_samples, 1),
            'val_spearman': float(spearman),
        }

        # Binary view of the labels: "good" complexes have label > 0.5.
        auc = self._safe_auc(all_labels, all_scores)
        if auc is not None:
            metrics['val_auc'] = auc

        return metrics

    # ------------------------------------------------------------------
    # Phase 2: selectivity fine-tuning
    # ------------------------------------------------------------------

    def train_step_phase2(self, batch, lambda_rank: float = 1.0, margin: float = 0.2,
                          lambda_ddg: float = 0.1):
        """Single training step for Phase 2 (selectivity margin + ddG auxiliary).

        Loss = MSE(pos, pos_label) + lambda_rank * (margin hinge + InfoNCE)
               + lambda_ddg * MSE(pos, contact_energy)

        The ddG term is used only when ``batch['pos']`` carries a non-empty
        'contact_energy' tensor. With SAM, the perturbed second pass optimizes
        the same objective, including the ddG term.
        """
        self.model.train()
        pos_in, neg_in, pos_label, pos_ce = self._phase2_inputs(batch)
        use_ddg = pos_ce is not None and pos_ce.shape[0] > 0

        def compute_loss():
            pos_scores = self._forward(pos_in)
            neg_scores = self._forward(neg_in)
            loss_reg = nn.functional.mse_loss(pos_scores, pos_label)
            loss_margin, infonce_loss = self._selectivity_losses(pos_scores, neg_scores, margin)
            loss_ddg = torch.tensor(0.0, device=self.device)
            if use_ddg:
                # NOTE(review): regresses scores onto contact energy directly;
                # presumably contact_energy is on the score's scale -- confirm.
                loss_ddg = nn.functional.mse_loss(pos_scores, pos_ce)
            loss = loss_reg + lambda_rank * (loss_margin + infonce_loss) + lambda_ddg * loss_ddg
            return loss, {
                'pos_scores': pos_scores,
                'neg_scores': neg_scores,
                'loss_reg': loss_reg,
                'loss_margin': loss_margin,
                'loss_infonce': infonce_loss,
                'loss_ddg': loss_ddg,
            }

        loss, parts = self._optimize(compute_loss)
        return self._phase2_result(loss, parts)

    def train_step_phase2_v2(self, batch, lambda_rank: float = 1.0, margin: float = 0.2,
                             lambda_ddg: float = 0.0, lambda_path: float = 0.5):
        """Phase 2 training step with multi-negative + path monotonicity.

        Adds to the Phase 2 objective a path term over the graph dicts in
        ``batch['path']`` (optional list), pushing scores to increase from neg,
        through the path frames in list order, up to pos. Path frames receive
        their 'esm_feats' when present. Honors SAM like the other steps.
        """
        self.model.train()
        pos_in, neg_in, pos_label, pos_ce = self._phase2_inputs(batch)
        path_in = [self._inputs(frame) for frame in batch.get('path', [])]
        use_ddg = pos_ce is not None and pos_ce.shape[0] > 0 and lambda_ddg > 0

        def compute_loss():
            pos_scores = self._forward(pos_in)
            neg_scores = self._forward(neg_in)
            path_scores = [self._forward(frame) for frame in path_in]

            loss_reg = nn.functional.mse_loss(pos_scores, pos_label)
            loss_margin, infonce_loss = self._selectivity_losses(pos_scores, neg_scores, margin)

            loss_ddg = torch.tensor(0.0, device=self.device)
            if use_ddg:
                loss_ddg = nn.functional.mse_loss(pos_scores, pos_ce)

            loss_path = torch.tensor(0.0, device=self.device)
            if path_scores and lambda_path > 0:
                loss_path = loss_path + self._path_monotonicity_loss(
                    pos_scores, neg_scores, path_scores, margin)

            loss = (loss_reg + lambda_rank * (loss_margin + infonce_loss)
                    + lambda_ddg * loss_ddg + lambda_path * loss_path)
            return loss, {
                'pos_scores': pos_scores,
                'neg_scores': neg_scores,
                'loss_reg': loss_reg,
                'loss_margin': loss_margin,
                'loss_infonce': infonce_loss,
                'loss_ddg': loss_ddg,
                'loss_path': loss_path,
            }

        loss, parts = self._optimize(compute_loss)
        return self._phase2_result(loss, parts)

    def run_phase2_path(self, train_loader, val_loader, n_epochs: int = 20,
                        lambda_rank: float = 1.0, margin: float = 0.2,
                        lambda_ddg: float = 0.0, lambda_path: float = 0.5):
        """Phase 2 with path-aware training loop.

        Resets ``best_val_metric`` and saves 'best_phase2.pt' whenever the
        validation selectivity gap improves.
        """
        logger.info(f"Starting Phase 2 (path-aware) for {n_epochs} epochs "
                    f"[lambda_rank={lambda_rank}, lambda_path={lambda_path}]")
        wandb.define_metric('phase2/step')
        wandb.define_metric('phase2/*', step_metric='phase2/step')
        self.best_val_metric = -float('inf')

        for epoch in range(n_epochs):
            loss_meter = AverageMeter()
            gap_meter = AverageMeter()
            path_meter = AverageMeter()

            for batch in train_loader:
                result = self.train_step_phase2_v2(
                    batch, lambda_rank, margin, lambda_ddg, lambda_path)
                B = len(result['pos_scores'])
                loss_meter.update(result['loss'], B)
                gap_meter.update(result['selectivity_gap'], B)
                path_meter.update(result['loss_path'], B)

                if self.global_step % 50 == 0:
                    wandb.log({
                        'phase2/train_loss': result['loss'],
                        'phase2/loss_margin': result['loss_margin'],
                        'phase2/loss_infonce': result['loss_infonce'],
                        'phase2/loss_path': result['loss_path'],
                        'phase2/selectivity_gap': result['selectivity_gap'],
                        'phase2/lr': self.optimizer.param_groups[0]['lr'],
                        'phase2/step': self.global_step,
                    })

            val_metrics = self.evaluate_phase2(val_loader)

            logger.info(
                f"Phase2-Path Epoch {epoch+1}/{n_epochs} | "
                f"Loss: {loss_meter.avg:.4f} | "
                f"Gap: {gap_meter.avg:.3f} | "
                f"Path: {path_meter.avg:.4f} | "
                f"Val Gap: {val_metrics['val_selectivity_gap']:.3f} | "
                f"Val Acc: {val_metrics['val_ranking_acc']:.3f}"
            )

            wandb.log({
                'phase2/epoch': epoch + 1,
                'phase2/train_loss_epoch': loss_meter.avg,
                'phase2/train_gap_epoch': gap_meter.avg,
                'phase2/train_path_loss_epoch': path_meter.avg,
                **{f'phase2/{k}': v for k, v in val_metrics.items()},
            })

            if val_metrics['val_selectivity_gap'] > self.best_val_metric:
                self.best_val_metric = val_metrics['val_selectivity_gap']
                self.save_checkpoint('best_phase2.pt', extra={'epoch': epoch, 'phase': 2})
                logger.info(f"  -> New best Phase 2 model (val_gap={self.best_val_metric:.3f})")

        logger.info("Phase 2 (path-aware) training complete.")

    def run_phase2(self, train_loader, val_loader, n_epochs: int = 20,
                   lambda_rank: float = 1.0, margin: float = 0.2,
                   lambda_ddg: float = 0.1):
        """Phase 2 training loop (selectivity fine-tuning + ddG auxiliary).

        Resets ``best_val_metric`` and saves 'best_phase2.pt' whenever the
        validation selectivity gap improves.
        """
        logger.info(f"Starting Phase 2 (selectivity fine-tuning) for {n_epochs} epochs "
                    f"[lambda_rank={lambda_rank}, lambda_ddg={lambda_ddg}]")
        wandb.define_metric('phase2/step')
        wandb.define_metric('phase2/*', step_metric='phase2/step')
        self.best_val_metric = -float('inf')

        for epoch in range(n_epochs):
            loss_meter = AverageMeter()
            gap_meter = AverageMeter()

            for batch in train_loader:
                result = self.train_step_phase2(batch, lambda_rank, margin, lambda_ddg)
                B = len(result['pos_scores'])
                loss_meter.update(result['loss'], B)
                gap_meter.update(result['selectivity_gap'], B)

                if self.global_step % 50 == 0:
                    wandb.log({
                        'phase2/train_loss': result['loss'],
                        'phase2/loss_margin': result['loss_margin'],
                        'phase2/loss_infonce': result['loss_infonce'],
                        'phase2/loss_ddg': result['loss_ddg'],
                        'phase2/selectivity_gap': result['selectivity_gap'],
                        'phase2/lr': self.optimizer.param_groups[0]['lr'],
                        'phase2/step': self.global_step,
                    })

            val_metrics = self.evaluate_phase2(val_loader)

            logger.info(
                f"Phase2 Epoch {epoch+1}/{n_epochs} | "
                f"Loss: {loss_meter.avg:.4f} | "
                f"Gap: {gap_meter.avg:.3f} | "
                f"Val Gap: {val_metrics['val_selectivity_gap']:.3f} | "
                f"Val Acc: {val_metrics['val_ranking_acc']:.3f}"
            )

            wandb.log({
                'phase2/epoch': epoch + 1,
                'phase2/train_loss_epoch': loss_meter.avg,
                'phase2/train_gap_epoch': gap_meter.avg,
                **{f'phase2/{k}': v for k, v in val_metrics.items()},
            })

            if val_metrics['val_selectivity_gap'] > self.best_val_metric:
                self.best_val_metric = val_metrics['val_selectivity_gap']
                self.save_checkpoint('best_phase2.pt', extra={'epoch': epoch, 'phase': 2})
                logger.info(f"  -> New best Phase 2 model (val_gap={self.best_val_metric:.3f})")

        logger.info("Phase 2 training complete.")

    @torch.no_grad()
    def evaluate_phase2(self, loader):
        """Evaluate selectivity on paired (pos, neg) val set.

        Batches without a 'pos' entry are skipped. Returns the mean pos-neg
        score gap, the fraction of pairs with pos > neg, and the mean pos/neg
        scores. With no paired batches it returns
        ``{'val_selectivity_gap': 0.0, 'val_ranking_acc': 0.5}``.
        """
        self.model.eval()
        all_pos_scores, all_neg_scores = [], []

        for batch in loader:
            if 'pos' not in batch:
                continue
            pos_scores = self._forward(self._inputs(batch['pos']))
            neg_scores = self._forward(self._inputs(batch['neg']))
            all_pos_scores.append(pos_scores.cpu().numpy())
            all_neg_scores.append(neg_scores.cpu().numpy())

        if not all_pos_scores:
            return {'val_selectivity_gap': 0.0, 'val_ranking_acc': 0.5}

        all_pos = np.concatenate(all_pos_scores)
        all_neg = np.concatenate(all_neg_scores)

        return {
            'val_selectivity_gap': float((all_pos - all_neg).mean()),
            'val_ranking_acc': float((all_pos > all_neg).mean()),
            'val_pos_score_mean': float(all_pos.mean()),
            'val_neg_score_mean': float(all_neg.mean()),
        }

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------

    def save_checkpoint(self, filename: str, extra: dict = None):
        """Save model, optimizer and LR-scheduler state to ``checkpoint_dir/filename``.

        ``extra`` entries are merged into the top level of the saved dict.
        """
        path = os.path.join(self.checkpoint_dir, filename)
        state = {
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict(),
            'global_step': self.global_step,
            'config': self.config,
        }
        if extra:
            state.update(extra)
        torch.save(state, path)
        logger.debug(f"Saved checkpoint: {path}")

    def load_checkpoint(self, filename: str):
        """Restore state saved by ``save_checkpoint``; return False if the file is missing.

        Checkpoints written before the scheduler state was saved still load;
        the LR schedule then restarts from its initial state.
        """
        path = os.path.join(self.checkpoint_dir, filename)
        if not os.path.exists(path):
            logger.warning(f"Checkpoint not found: {path}")
            return False
        # NOTE: torch.load unpickles data; only load checkpoints from trusted sources.
        state = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state['model_state'])
        self.optimizer.load_state_dict(state['optimizer_state'])
        if 'scheduler_state' in state:
            self.scheduler.load_state_dict(state['scheduler_state'])
        self.global_step = state.get('global_step', 0)
        logger.info(f"Loaded checkpoint from {path} (step {self.global_step})")
        return True

    # ------------------------------------------------------------------
    # Test evaluation
    # ------------------------------------------------------------------

    @torch.no_grad()
    def evaluate_test(self, test_loader, phase: int = 2):
        """Full evaluation on test set with all metrics.

        Paired batches (containing 'pos') contribute both sides, typed 'pos'
        and 'neg'; other batches use their own 'type' list. ``phase`` is
        currently unused. Returns ``(metrics, scores, labels, types)`` with
        scores/labels as numpy arrays.
        """
        self.model.eval()
        all_scores, all_labels, all_types = [], [], []

        for batch in test_loader:
            if 'pos' in batch:
                for key in ('pos', 'neg'):
                    graph = batch[key]
                    scores = self._forward(self._inputs(graph))
                    all_scores.extend(scores.cpu().numpy().tolist())
                    all_labels.extend(graph['label'].cpu().numpy().tolist())
                    all_types.extend([key] * len(scores))
            else:
                scores = self._forward(self._inputs(batch))
                all_scores.extend(scores.cpu().numpy().tolist())
                all_labels.extend(batch['label'].cpu().numpy().tolist())
                all_types.extend(batch['type'])

        all_scores = np.array(all_scores)
        all_labels = np.array(all_labels)

        metrics = {}

        # NaN is truthy, so the previous `correlation or 0` never replaced it.
        spearman = spearmanr(all_scores, all_labels).correlation
        metrics['test_spearman'] = 0.0 if np.isnan(spearman) else float(spearman)

        auc = self._safe_auc(all_labels, all_scores)
        if auc is not None:
            metrics['test_auc'] = auc

        # Selectivity gap between positive-state and negative/apo-state samples.
        pos_mask = np.array([t == 'pos' or t == 'positive' for t in all_types])
        neg_mask = np.array([t == 'neg' or t == 'negative_apo' for t in all_types])
        if pos_mask.sum() > 0 and neg_mask.sum() > 0:
            metrics['test_selectivity_gap'] = float(all_scores[pos_mask].mean() - all_scores[neg_mask].mean())

        logger.info(f"Test evaluation: {metrics}")
        wandb.log({f'test/{k}': v for k, v in metrics.items()})

        return metrics, all_scores, all_labels, all_types
|
|