""" Evaluation script for the trained Q_theta scorer. Computes: 1. Selectivity metrics (gap, ranking accuracy, AUC) 2. DockQ correlation (Spearman/Pearson) 3. Score distributions (violin plots) 4. Best-of-K analysis (as function of K) 5. Per-target breakdown Usage: python code/scripts/evaluate.py \ --target cam \ --checkpoint checkpoints/Q_theta_phase2.pt \ --data_dir data/processed \ --gpu 7 """ import os import sys import argparse import logging import json import numpy as np import torch import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt from scipy.stats import spearmanr, pearsonr from sklearn.metrics import roc_auc_score, roc_curve _CODE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) if _CODE_DIR not in sys.path: sys.path.insert(0, _CODE_DIR) from models.scorer import build_model from data.dataset import TwoStateComplexDataset, collate_fn from torch.utils.data import DataLoader logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') logger = logging.getLogger(__name__) def compute_best_of_k(pos_scores, K_values=None, threshold=0.7): """ Simulate best-of-K selection: what fraction of draws contain at least one good binder? Assumes pos_scores are from a distribution of candidate binders for goal state X+. """ if K_values is None: K_values = [1, 2, 5, 10, 20, 50, 100] results = {} n = len(pos_scores) n_trials = 1000 for K in K_values: successes = 0 for _ in range(n_trials): idxs = np.random.choice(n, size=min(K, n), replace=False) best_score = pos_scores[idxs].max() if best_score >= threshold: successes += 1 results[K] = successes / n_trials return results def compute_selectivity_margin(pos_scores, neg_scores): """Compute per-sample selectivity margin S_theta.""" eps = 1e-6 pos_logit = np.log(pos_scores.clip(eps, 1-eps) / (1-pos_scores).clip(eps)) neg_logit = np.log(neg_scores.clip(eps, 1-eps) / (1-neg_scores).clip(eps)) selectivity = pos_logit - np.log(np.exp(neg_logit) + 1e-8) return selectivity def plot_score_distributions(pos_scores, neg_scores, decoy_scores=None, title='Score Distributions', outpath=None): """Violin plot of score distributions for different complex types.""" fig, ax = plt.subplots(figsize=(8, 6)) data = [pos_scores, neg_scores] labels = ['Positive\n(X+, Y)', 'Negative\n(X0, Y)'] colors = ['#2196F3', '#F44336'] if decoy_scores is not None and len(decoy_scores) > 0: data.append(decoy_scores) labels.append('Decoys\n(X+, Y~)') colors.append('#FF9800') parts = ax.violinplot(data, positions=range(len(data)), showmedians=True) for i, (pc, c) in enumerate(zip(parts['bodies'], colors)): pc.set_facecolor(c) pc.set_alpha(0.7) ax.set_xticks(range(len(data))) ax.set_xticklabels(labels) ax.set_ylabel('Q_theta Score', fontsize=12) ax.set_title(title, fontsize=14) ax.set_ylim(0, 1) ax.axhline(0.5, color='gray', linestyle='--', alpha=0.5, label='Decision boundary') ax.legend() # Add mean + std annotations for i, (d, c) in enumerate(zip(data, colors)): ax.text(i, 0.02, f'μ={d.mean():.2f}\nσ={d.std():.2f}', ha='center', fontsize=9, color=c) plt.tight_layout() if outpath: plt.savefig(outpath, dpi=150, bbox_inches='tight') logger.info(f"Saved plot to {outpath}") plt.close() def plot_roc_curve(labels, scores, title='ROC Curve', outpath=None): """Plot ROC curve for positive vs negative classification.""" fpr, tpr, _ = roc_curve(labels, scores) auc = roc_auc_score(labels, scores) fig, ax = plt.subplots(figsize=(6, 6)) ax.plot(fpr, tpr, 'b-', lw=2, label=f'AUC = {auc:.3f}') ax.plot([0, 1], [0, 1], 'k--', lw=1) ax.set_xlabel('False Positive Rate') ax.set_ylabel('True Positive Rate') ax.set_title(title) ax.legend() plt.tight_layout() if outpath: plt.savefig(outpath, dpi=150, bbox_inches='tight') plt.close() return auc def plot_best_of_k(results, outpath=None): """Plot best-of-K success rate as a function of K.""" Ks = sorted(results.keys()) success_rates = [results[K] for K in Ks] fig, ax = plt.subplots(figsize=(8, 5)) ax.semilogx(Ks, success_rates, 'b-o', lw=2, markersize=8) ax.set_xlabel('K (number of candidates)', fontsize=12) ax.set_ylabel('Success rate (best score > 0.7)', fontsize=12) ax.set_title('Best-of-K Analysis', fontsize=14) ax.set_ylim(0, 1.05) ax.grid(True, alpha=0.3) ax.axhline(0.8, color='red', linestyle='--', alpha=0.5, label='80% success') ax.legend() plt.tight_layout() if outpath: plt.savefig(outpath, dpi=150, bbox_inches='tight') plt.close() @torch.no_grad() def evaluate(model, loader, device): """Run model on a dataset and collect all predictions.""" model.eval() all_scores, all_labels, all_types, all_pdbs = [], [], [], [] for batch in loader: esm_feats = batch['esm_feats'].to(device) if 'esm_feats' in batch else None scores = model( batch['node_feats'].to(device), batch['edge_feats'].to(device), batch['node_mask'].to(device), esm_feats=esm_feats, ) all_scores.extend(scores.cpu().numpy().tolist()) all_labels.extend(batch['label'].numpy().tolist()) all_types.extend(batch['type']) all_pdbs.extend(batch['pdb']) return (np.array(all_scores), np.array(all_labels), np.array(all_types), np.array(all_pdbs)) def main(): parser = argparse.ArgumentParser(description='Evaluate Allo-Designer Q_theta scorer') parser.add_argument('--target', default='cam', help='Target name (cam, abl, era, or any custom target with data in data/processed/)') parser.add_argument('--all_targets', action='store_true', help='Evaluate on all available targets and produce aggregated results') parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint') parser.add_argument('--data_dir', default='data/processed') parser.add_argument('--split', choices=['val', 'test'], default='test') parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--gpu', type=int, default=7) parser.add_argument('--outdir', default='results') parser.add_argument('--bok_threshold', type=float, default=0.7, help='Score threshold for best-of-K (default 0.7; use per-target value for calibrated results)') parser.add_argument('--esm_dir', default=None, help='Path to ESM-2 embedding cache (auto-detected at /esm2_embeddings if omitted)') parser.add_argument('--no_wandb', action='store_true', help='(ignored; here for CLI compatibility)') args = parser.parse_args() # Auto-detect ESM dir under data_dir if args.esm_dir is None: cand = os.path.join(args.data_dir, 'esm2_embeddings') if os.path.isdir(cand): args.esm_dir = cand device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') os.makedirs(args.outdir, exist_ok=True) os.makedirs(f'{args.outdir}/figures', exist_ok=True) os.makedirs(f'{args.outdir}/tables', exist_ok=True) # Load model state = torch.load(args.checkpoint, map_location=device) config = state.get('config', {}) model = build_model(config).to(device) model.load_state_dict(state['model_state']) logger.info(f"Loaded model from {args.checkpoint}") # Load dataset data_path = os.path.join(args.data_dir, args.target, f'{args.split}.pkl') if not os.path.exists(data_path): logger.error(f"Data not found: {data_path}") sys.exit(1) dataset = TwoStateComplexDataset(data_path, max_nodes=128, esm_dir=args.esm_dir, target_name=args.target) loader = DataLoader( dataset, batch_size=args.batch_size, shuffle=False, num_workers=2, collate_fn=collate_fn ) # Run evaluation logger.info(f"Evaluating on {len(dataset)} samples...") scores, labels, types, pdbs = evaluate(model, loader, device) # Separate by type pos_mask = (types == 'positive') neg_apo_mask = (types == 'negative_apo') decoy_mask = np.array(['decoy' in t for t in types]) pos_scores = scores[pos_mask] neg_scores = scores[neg_apo_mask] decoy_scores = scores[decoy_mask] logger.info(f"\n{'='*50}") logger.info(f"Results for {args.target} ({args.split})") logger.info(f"{'='*50}") logger.info(f"Positive samples: {pos_mask.sum()}") logger.info(f"Negative (apo) samples: {neg_apo_mask.sum()}") logger.info(f"Decoy samples: {decoy_mask.sum()}") # --- Core metrics --- metrics = {} # 1. Spearman correlation with DockQ labels sp, p_val = spearmanr(scores, labels) metrics['spearman_all'] = float(sp) metrics['spearman_pval'] = float(p_val) logger.info(f"\nSpearman(Q_theta, DockQ): {sp:.3f} (p={p_val:.3e})") # 2. Selectivity gap (positive vs negative_apo) if pos_mask.sum() > 0 and neg_apo_mask.sum() > 0: gap = float(pos_scores.mean() - neg_scores.mean()) ranking_acc = float((pos_scores.mean() > neg_scores).mean() if len(neg_scores) > 0 else 0.5) metrics['selectivity_gap'] = gap metrics['pos_score_mean'] = float(pos_scores.mean()) metrics['neg_score_mean'] = float(neg_scores.mean()) metrics['pos_score_std'] = float(pos_scores.std()) metrics['neg_score_std'] = float(neg_scores.std()) logger.info(f"Selectivity gap (pos - neg): {gap:.3f}") logger.info(f" Pos: {pos_scores.mean():.3f} ± {pos_scores.std():.3f}") logger.info(f" Neg: {neg_scores.mean():.3f} ± {neg_scores.std():.3f}") # 3. AUC for positive vs negative if pos_mask.sum() > 0 and neg_apo_mask.sum() > 0: pn_scores = np.concatenate([pos_scores, neg_scores]) pn_labels = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))]) auc = roc_auc_score(pn_labels, pn_scores) metrics['auc_pos_vs_neg'] = float(auc) logger.info(f"AUC (pos vs neg_apo): {auc:.3f}") # ROC curve plot_roc_curve( pn_labels, pn_scores, title=f'ROC: Positive vs Negative Apo ({args.target.upper()})', outpath=f'{args.outdir}/figures/roc_{args.target}_{args.split}.png' ) # 4. AUC for quality classification (DockQ > 0.5) binary = (labels > 0.5).astype(int) if binary.sum() > 0 and binary.sum() < len(binary): auc_quality = roc_auc_score(binary, scores) metrics['auc_quality'] = float(auc_quality) logger.info(f"AUC (quality>0.5): {auc_quality:.3f}") # 5. Best-of-K analysis if len(pos_scores) > 0: bok_results = compute_best_of_k(pos_scores, K_values=[1, 2, 5, 10, 20, 50], threshold=args.bok_threshold) metrics['best_of_k'] = {str(K): float(v) for K, v in bok_results.items()} logger.info(f"\nBest-of-K success rates:") for K, rate in bok_results.items(): logger.info(f" K={K:3d}: {rate:.3f}") plot_best_of_k( bok_results, outpath=f'{args.outdir}/figures/best_of_k_{args.target}_{args.split}.png' ) # 6. Score distributions plot plot_score_distributions( pos_scores if len(pos_scores) > 0 else np.array([]), neg_scores if len(neg_scores) > 0 else np.array([]), decoy_scores if len(decoy_scores) > 0 else None, title=f'Q_theta Score Distributions ({args.target.upper()})', outpath=f'{args.outdir}/figures/score_dist_{args.target}_{args.split}.png' ) # Save metrics out_json = f'{args.outdir}/tables/eval_{args.target}_{args.split}.json' with open(out_json, 'w') as f: json.dump(metrics, f, indent=2) logger.info(f"\nSaved metrics to {out_json}") # Print summary table logger.info(f"\n{'='*50}") logger.info("SUMMARY TABLE") logger.info(f"{'='*50}") logger.info(f"{'Metric':<30} {'Value':>10}") logger.info(f"{'-'*42}") for k, v in metrics.items(): if isinstance(v, float): logger.info(f"{k:<30} {v:>10.4f}") logger.info(f"{'='*50}") if __name__ == '__main__': main()