"""Q_theta: state-selectivity scorer for Allo-Designer.

Architecture: dense edge-biased graph transformer.

- Input: padded interface graph (node features + pairwise edge features).
- The features are expected to be SE(3)-invariant, i.e. derived by the
  caller from distances/angles in backbone frames; this module does not
  compute or verify that itself.
- Output: Q_theta(X, Y) in (0, 1), a probability-like
  compatibility/selectivity score.

No torch_geometric dependency: all layers use dense attention over padded
``[B, N, N]`` pair tensors with additive edge biases.
"""

import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def _make_ffn(d_model: int, ff_mult: int, dropout: float) -> nn.Sequential:
    """Build the position-wise feed-forward block shared by all layers.

    The submodule order (Linear, GELU, Dropout, Linear) is part of the
    checkpoint layout (``ff.0`` / ``ff.3``) and must not be changed.
    """
    return nn.Sequential(
        nn.Linear(d_model, d_model * ff_mult),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(d_model * ff_mult, d_model),
    )


class RBFLayer(nn.Module):
    """Radial-basis embedding of distances with a single learnable width.

    The ``n_bins`` centers are fixed (evenly spaced on ``[d_min, d_max]``
    and stored as a buffer); only the shared log-width is trained.
    """

    def __init__(self, n_bins: int = 16, d_min: float = 0., d_max: float = 20.):
        super().__init__()
        self.register_buffer('centers', torch.linspace(d_min, d_max, n_bins))
        # Width is parameterised in log-space so sigma stays positive.
        self.log_sigma = nn.Parameter(torch.zeros(1))

    def forward(self, dist: torch.Tensor) -> torch.Tensor:
        """Map ``dist`` of shape ``[...]`` to Gaussian bins ``[..., n_bins]``."""
        sigma = torch.exp(self.log_sigma)
        offset = dist.unsqueeze(-1) - self.centers
        return torch.exp(-(offset ** 2) / (2 * sigma ** 2))


class EdgeBiasedMHA(nn.Module):
    """Multi-head self-attention with additive per-head edge biases.

    Attention logits are ``A_ij = Q_i . K_j / sqrt(d_head) + b_ij`` where
    ``b_ij`` is a linear projection of the pairwise edge features.
    """

    def __init__(self, d_model: int, n_heads: int, d_edge: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.scale = math.sqrt(self.d_head)

        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model)
        self.edge_proj = nn.Linear(d_edge, n_heads)  # edge features -> per-head bias
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        edge_feats: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply edge-biased self-attention.

        Args:
            x: ``[B, N, d_model]`` node states.
            edge_feats: ``[B, N, N, d_edge]`` pairwise features.
            mask: optional ``[B, N]`` bool, True = valid residue. Padded
                positions are excluded as attention *keys*.

        Returns:
            ``[B, N, d_model]`` attended node states.
        """
        B, N, D = x.shape
        H = self.n_heads

        # One fused projection, then split into q/k/v of shape [B, H, N, d_head].
        qkv = self.qkv_proj(x).reshape(B, N, 3, H, self.d_head).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn_logits = (q @ k.transpose(-2, -1)) / self.scale  # [B, H, N, N]

        # Edge bias: [B, N, N, H] -> [B, H, N, N]
        edge_bias = self.edge_proj(edge_feats).permute(0, 3, 1, 2)
        attn_logits = attn_logits + edge_bias

        if mask is not None:
            padding = ~mask  # [B, N], True = padding
            attn_logits = attn_logits.masked_fill(
                padding[:, None, None, :],  # broadcast over heads and queries
                float('-inf'),
            )

        attn_weights = self.dropout(F.softmax(attn_logits, dim=-1))
        # A row whose keys are all masked softmaxes to NaN; zero it out so the
        # output for that row is simply out_proj's bias.
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)

        out = attn_weights @ v  # [B, H, N, d_head]
        out = out.transpose(1, 2).reshape(B, N, D)
        return self.out_proj(out)


class InterfaceTransformerLayer(nn.Module):
    """Pre-norm transformer layer: edge-biased attention + feed-forward."""

    def __init__(self, d_model: int, n_heads: int, d_edge: int,
                 ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        self.attn = EdgeBiasedMHA(d_model, n_heads, d_edge, dropout)
        self.ff = _make_ffn(d_model, ff_mult, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        edge_feats: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return ``x`` after residual attention and residual feed-forward."""
        x = x + self.drop(self.attn(self.norm1(x), edge_feats, mask))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x


class GATLayer(nn.Module):
    """Dense multi-head GAT layer with pre-norm.

    Attention scores depend on node states only; ``edge_feats`` is accepted
    for interface parity with the other layers and is ignored.
    """

    def __init__(self, d_model: int, n_heads: int,
                 ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.W = nn.Linear(d_model, d_model, bias=False)
        self.a_l = nn.Parameter(torch.randn(n_heads, self.d_head))
        self.a_r = nn.Parameter(torch.randn(n_heads, self.d_head))
        # Xavier init is applied through a [1, H, d_head] view of each
        # parameter, which writes through to a_l / a_r in place.
        nn.init.xavier_uniform_(self.a_l.unsqueeze(0))
        nn.init.xavier_uniform_(self.a_r.unsqueeze(0))

        self.out_proj = nn.Linear(d_model, d_model)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.attn_drop = nn.Dropout(dropout)

        self.ff = _make_ffn(d_model, ff_mult, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        edge_feats: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply GAT attention over all (unmasked) nodes, then feed-forward.

        Args:
            x: ``[B, N, d_model]`` node states.
            edge_feats: unused.
            mask: optional ``[B, N]`` bool, True = valid residue.
        """
        B, N, D = x.shape
        H = self.n_heads

        h = self.norm1(x)
        Wh = self.W(h).view(B, N, H, self.d_head)  # [B, N, H, d_head]

        # Additive GAT score e_ij = LeakyReLU(a_l . Wh_i + a_r . Wh_j).
        e_l = (Wh * self.a_l).sum(-1)  # [B, N, H]
        e_r = (Wh * self.a_r).sum(-1)  # [B, N, H]
        attn = self.leaky_relu(e_l.unsqueeze(2) + e_r.unsqueeze(1))  # [B, N, N, H]
        attn = attn.permute(0, 3, 1, 2)  # [B, H, N, N]

        if mask is not None:
            attn = attn.masked_fill(~mask[:, None, None, :], float('-inf'))

        attn = self.attn_drop(F.softmax(attn, dim=-1))
        attn = torch.nan_to_num(attn, nan=0.0)  # fully-masked rows -> 0

        out = torch.einsum('bhnm,bmhd->bnhd', attn, Wh)
        out = out.reshape(B, N, D)

        x = x + self.drop(self.out_proj(out))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x


class GCNLayer(nn.Module):
    """Pre-norm GCN-style layer with softmax-normalised edge weights.

    Each node aggregates projected neighbour messages weighted by a scalar
    derived from the edge features alone (no query/key interaction).
    """

    def __init__(self, d_model: int, d_edge: int,
                 ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        self.msg_proj = nn.Linear(d_model, d_model, bias=False)
        self.edge_weight = nn.Linear(d_edge, 1)

        self.ff = _make_ffn(d_model, ff_mult, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        edge_feats: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Aggregate edge-weighted messages, then apply feed-forward.

        Args:
            x: ``[B, N, d_model]`` node states.
            edge_feats: ``[B, N, N, d_edge]`` pairwise features.
            mask: optional ``[B, N]`` bool, True = valid residue.
        """
        msg = self.msg_proj(self.norm1(x))  # [B, N, D]

        w = self.edge_weight(edge_feats).squeeze(-1)  # [B, N, N]
        if mask is not None:
            w = w.masked_fill(~mask[:, None, :], float('-inf'))
        w = F.softmax(w, dim=-1)
        w = torch.nan_to_num(w, nan=0.0)  # fully-masked rows -> 0

        agg = torch.bmm(w, msg)  # [B, N, D]

        x = x + self.drop(agg)
        x = x + self.drop(self.ff(self.norm2(x)))
        return x


class CrossChainTransformerLayer(nn.Module):
    """Cross-chain attention: each node attends only to the other chain."""

    def __init__(self, d_model: int, n_heads: int, d_edge: int,
                 ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.scale = math.sqrt(self.d_head)

        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model)
        self.edge_proj = nn.Linear(d_edge, n_heads)
        self.attn_drop = nn.Dropout(dropout)

        self.ff = _make_ffn(d_model, ff_mult, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        edge_feats: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chain_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply edge-biased attention restricted to cross-chain pairs.

        Args:
            x: ``[B, N, d_model]`` node states.
            edge_feats: ``[B, N, N, d_edge]`` pairwise features.
            mask: optional ``[B, N]`` bool, True = valid residue.
            chain_mask: optional ``[B, N]`` chain indicator
                (0 = receptor, 1 = binder). Pairs with equal indicator are
                blocked. When omitted this is ordinary self-attention.

        Returns:
            ``[B, N, d_model]`` updated node states.
        """
        B, N, D = x.shape
        H = self.n_heads

        h = self.norm1(x)
        qkv = self.qkv_proj(h).reshape(B, N, 3, H, self.d_head).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each [B, H, N, d_head]

        attn_logits = (q @ k.transpose(-2, -1)) / self.scale  # [B, H, N, N]
        edge_bias = self.edge_proj(edge_feats).permute(0, 3, 1, 2)  # [B, H, N, N]
        attn_logits = attn_logits + edge_bias

        # Exclude padded keys.
        if mask is not None:
            attn_logits = attn_logits.masked_fill(~mask[:, None, None, :], float('-inf'))

        # Exclude same-chain pairs (this includes the diagonal).
        if chain_mask is not None:
            same_chain = chain_mask.unsqueeze(1) == chain_mask.unsqueeze(2)  # [B, N, N]
            attn_logits = attn_logits.masked_fill(same_chain[:, None, :, :], float('-inf'))

        attn_weights = self.attn_drop(F.softmax(attn_logits, dim=-1))
        # Nodes with no admissible partner (e.g. the other chain is absent)
        # produce an all -inf row -> NaN after softmax; zero those rows.
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)

        out = (attn_weights @ v).transpose(1, 2).reshape(B, N, D)

        x = x + self.drop(self.out_proj(out))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x


class EdgeUpdateLayer(nn.Module):
    """Residual edge-feature update conditioned on the two endpoint nodes.

    Nodes are first projected to a small dimension so the broadcast
    ``[B, N, N, .]`` pair tensor stays cheap.
    """

    def __init__(self, d_model: int, d_edge: int, dropout: float = 0.1):
        super().__init__()
        d_proj = min(32, d_model // 4)  # low-dim projection to save memory
        self.proj_i = nn.Linear(d_model, d_proj, bias=False)
        self.proj_j = nn.Linear(d_model, d_proj, bias=False)
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * d_proj + d_edge, d_edge),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_edge, d_edge),
        )
        self.norm = nn.LayerNorm(d_edge)

    def forward(
        self,
        h: torch.Tensor,
        e: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return updated edges ``e + MLP([h_i, h_j, LayerNorm(e)])``.

        Args:
            h: ``[B, N, d_model]`` node states.
            e: ``[B, N, N, d_edge]`` edge states.
            mask: accepted for interface symmetry; currently unused.
        """
        n_nodes = h.shape[1]
        hi = self.proj_i(h).unsqueeze(2).expand(-1, -1, n_nodes, -1)  # [B, N, N, d_proj]
        hj = self.proj_j(h).unsqueeze(1).expand(-1, n_nodes, -1, -1)  # [B, N, N, d_proj]
        pair_input = torch.cat([hi, hj, self.norm(e)], dim=-1)
        return e + self.edge_mlp(pair_input)


class InterfaceGNN(nn.Module):
    """Q_theta scorer: dense graph network over a padded interface graph.

    Input:
        node_feats: ``[B, N, node_dim]`` per-residue features; the last
            channel is read as the chain indicator (0 = receptor, 1 = binder).
        edge_feats: ``[B, N, N, edge_dim]`` pairwise edge features.
        mask: ``[B, N]`` bool (True = valid residue, False = padding).

    Output:
        scores: ``[B]`` in (0, 1) = Q_theta(X, Y).

    Constructor options:
        backbone: ``'transformer'``, ``'gat'``, ``'gcn'`` or ``'crosschain'``
            (alternating self-attention / cross-chain attention layers).
            Any other value raises ``ValueError``.
        pooling: ``'attention'`` selects learned attention pooling; any other
            value (including the default ``'meanmax'``) uses mean + max.
        edge_update: if True, edge states are refreshed after every layer.
        esm_dim: 0 disables ESM inputs; > 0 is the raw ESM embedding width,
            which is projected to ``esm_proj_dim`` and concatenated to the
            node features.
    """

    def __init__(
        self,
        node_dim: int = 28,
        edge_dim: int = 37,
        hidden_dim: int = 128,
        n_layers: int = 4,
        n_heads: int = 8,
        ff_mult: int = 4,
        dropout: float = 0.1,
        backbone: str = 'transformer',
        pooling: str = 'meanmax',   # 'meanmax' or 'attention'
        edge_update: bool = False,
        esm_dim: int = 0,           # 0 = no ESM; >0 = ESM embedding dim to project
        esm_proj_dim: int = 128,    # projection dim for ESM features
        esm_dropout: float = 0.0,   # dropout on ESM projection
    ):
        super().__init__()

        # NOTE: submodules are created in a fixed order so seeded
        # initialisation stays reproducible across versions of this file.
        actual_node_dim = node_dim + (esm_proj_dim if esm_dim > 0 else 0)
        self.esm_dim = esm_dim
        if esm_dim > 0:
            esm_layers = [
                nn.Linear(esm_dim, esm_proj_dim),
                nn.LayerNorm(esm_proj_dim),
                nn.GELU(),
            ]
            if esm_dropout > 0:
                esm_layers.append(nn.Dropout(esm_dropout))
            self.esm_proj = nn.Sequential(*esm_layers)

        self.node_embed = nn.Sequential(
            nn.Linear(actual_node_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
        self.edge_embed = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
        )

        d_edge_hidden = hidden_dim // 2
        if backbone == 'transformer':
            self.layers = nn.ModuleList([
                InterfaceTransformerLayer(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout)
                for _ in range(n_layers)
            ])
        elif backbone == 'gat':
            self.layers = nn.ModuleList([
                GATLayer(hidden_dim, n_heads, ff_mult, dropout)
                for _ in range(n_layers)
            ])
        elif backbone == 'gcn':
            self.layers = nn.ModuleList([
                GCNLayer(hidden_dim, d_edge_hidden, ff_mult, dropout)
                for _ in range(n_layers)
            ])
        elif backbone == 'crosschain':
            # Even layers: full self-attention; odd layers: cross-chain only.
            self.layers = nn.ModuleList([
                (InterfaceTransformerLayer if i % 2 == 0 else CrossChainTransformerLayer)(
                    hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout
                )
                for i in range(n_layers)
            ])
        else:
            raise ValueError(f"Unknown backbone: {backbone}")

        self.norm_out = nn.LayerNorm(hidden_dim)

        # Optional per-layer edge updates.
        self.edge_update = edge_update
        if edge_update:
            self.edge_update_layers = nn.ModuleList([
                EdgeUpdateLayer(hidden_dim, d_edge_hidden, dropout)
                for _ in range(n_layers)
            ])

        # Graph-level pooling.
        self.pooling = pooling
        if pooling == 'attention':
            self.attn_pool = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.Tanh(),
                nn.Linear(hidden_dim // 2, 1),
            )
            pool_dim = hidden_dim
        else:
            pool_dim = 2 * hidden_dim  # concat(mean, max)

        # Scoring head -> one logit per graph.
        self.head = nn.Sequential(
            nn.Linear(pool_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(
        self,
        node_feats: torch.Tensor,
        edge_feats: torch.Tensor,
        mask: torch.Tensor,
        esm_feats: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Score a batch of padded interface graphs.

        Args:
            node_feats: ``[B, N, node_dim]``.
            edge_feats: ``[B, N, N, edge_dim]``.
            mask: ``[B, N]`` bool, True = valid residue.
            esm_feats: ``[B, N, esm_dim]`` ESM embeddings. Required when the
                model was built with ``esm_dim > 0``; ignored otherwise.

        Returns:
            ``[B]`` scores in (0, 1).

        Raises:
            ValueError: if the model was built with ``esm_dim > 0`` but
                ``esm_feats`` is not supplied.
        """
        # Chain indicator is read from the raw node features, before any
        # ESM channels are appended behind it.
        chain_mask = node_feats[:, :, -1]  # [B, N]: 0 = receptor, 1 = binder

        if self.esm_dim > 0:
            if esm_feats is None:
                # Without this check node_embed would fail with an opaque
                # matmul shape error, because its input width includes the
                # projected ESM channels.
                raise ValueError(
                    f"Model was built with esm_dim={self.esm_dim} but "
                    "esm_feats was not provided to forward()."
                )
            esm_proj = self.esm_proj(esm_feats)  # [B, N, esm_proj_dim]
            node_feats = torch.cat([node_feats, esm_proj], dim=-1)

        h = self.node_embed(node_feats)  # [B, N, hidden_dim]
        e = self.edge_embed(edge_feats)  # [B, N, N, hidden_dim // 2]

        for i, layer in enumerate(self.layers):
            if isinstance(layer, CrossChainTransformerLayer):
                h = layer(h, e, mask, chain_mask=chain_mask)
            else:
                h = layer(h, e, mask)
            if self.edge_update:
                e = self.edge_update_layers[i](h, e, mask)

        h = self.norm_out(h)  # [B, N, hidden_dim]

        if self.pooling == 'attention':
            h_pool = self._attention_pool(h, mask)  # [B, hidden_dim]
        else:
            h_pool = self._mean_max_pool(h, mask)   # [B, 2 * hidden_dim]

        logits = self.head(h_pool).squeeze(-1)  # [B]
        return torch.sigmoid(logits)

    def _attention_pool(self, h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Softmax-weighted sum of node states over valid residues."""
        attn_logits = self.attn_pool(h).squeeze(-1)  # [B, N]
        attn_logits = attn_logits.masked_fill(~mask, float('-inf'))
        attn_weights = F.softmax(attn_logits, dim=-1).unsqueeze(-1)  # [B, N, 1]
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)  # empty graph -> 0
        return (h * attn_weights).sum(dim=1)

    @staticmethod
    def _mean_max_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Concatenate masked mean and masked max over the node axis."""
        mask_f = mask.float().unsqueeze(-1)  # [B, N, 1]
        h_masked = h * mask_f
        h_mean = h_masked.sum(dim=1) / (mask_f.sum(dim=1) + 1e-8)
        # Push padded rows far below any real activation before the max.
        h_max_input = h_masked + (1 - mask_f) * (-1e9)
        h_max = h_max_input.max(dim=1).values
        return torch.cat([h_mean, h_max], dim=-1)


class AlloDesignerScorer(nn.Module):
    """Q_theta model wrapper bundling the network with its training losses.

    Two-stage training objective:
        Phase 1: DockQ regression (MSE loss).
        Phase 2: selectivity margin ranking (contrastive loss).

    Selectivity margin (Eq. 3 of the accompanying paper, per the original
    author's note):
        S_theta(Y; X+, N) = logit(Q(X+, Y)) - log sum_{X-} exp(logit(Q(X-, Y)))
    """

    def __init__(self, node_dim=28, edge_dim=37, hidden_dim=128,
                 n_layers=4, n_heads=8, dropout=0.1, backbone='transformer',
                 pooling='meanmax', edge_update=False,
                 esm_dim=0, esm_proj_dim=128, esm_dropout=0.0):
        super().__init__()
        self.gnn = InterfaceGNN(node_dim, edge_dim, hidden_dim,
                                n_layers, n_heads, dropout=dropout,
                                backbone=backbone, pooling=pooling,
                                edge_update=edge_update,
                                esm_dim=esm_dim, esm_proj_dim=esm_proj_dim,
                                esm_dropout=esm_dropout)

    def forward(self, node_feats, edge_feats, mask, esm_feats=None):
        """Delegate to :class:`InterfaceGNN`; returns ``[B]`` scores in (0, 1)."""
        return self.gnn(node_feats, edge_feats, mask, esm_feats=esm_feats)

    @staticmethod
    def _safe_logit(p: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Return ``log(p / (1 - p))`` with both factors clamped away from 0."""
        return torch.log(p.clamp(eps, 1 - eps) / (1 - p).clamp(eps))

    def compute_dockq_loss(self, scores, dockq_labels):
        """Phase 1: MSE regression loss of ``scores`` against DockQ labels."""
        return F.mse_loss(scores, dockq_labels.float())

    def compute_selectivity_loss(self, pos_scores, neg_scores_list,
                                 margin: float = 0.2):
        """Phase 2: selectivity loss plus a hinge margin term.

        For each binder Y with ``pos = Q(X+, Y)`` and negatives
        ``neg_k = Q(X-_k, Y)``:

            S_theta  = logit(pos) - logsumexp_k(logit(neg_k))
            L_select = -mean(S_theta)
            L_margin = mean_{b,k} max(0, margin - (pos - neg_k))

        Args:
            pos_scores: ``[B]`` scores for the positive state.
            neg_scores_list: non-empty sequence of ``[B]`` score tensors, one
                per negative state.
            margin: required gap (in probability space) between the positive
                score and each negative score.

        Returns:
            Scalar tensor ``L_select + L_margin``.

        Raises:
            ValueError: if ``neg_scores_list`` is empty.
        """
        neg_list = list(neg_scores_list)
        if len(neg_list) == 0:
            raise ValueError("neg_scores_list must contain at least one tensor")

        neg_scores = torch.stack(neg_list, dim=-1)  # [B, n_neg]

        # InfoNCE-style selectivity margin in logit space.
        pos_logit = self._safe_logit(pos_scores)     # [B]
        neg_logits = self._safe_logit(neg_scores)    # [B, n_neg]
        selectivity = pos_logit - torch.logsumexp(neg_logits, dim=-1)  # [B]
        selectivity_loss = -selectivity.mean()

        # Hinge margin in probability space, averaged over batch and negatives.
        margin_loss = F.relu(margin - (pos_scores.unsqueeze(-1) - neg_scores)).mean()

        return selectivity_loss + margin_loss

    def compute_path_selectivity_loss(self, pos_scores, neg_scores_list,
                                      path_scores_list, path_taus,
                                      margin=0.2, path_lambda=0.5):
        """Selectivity loss with path-monotonicity regularisation.

        Args:
            pos_scores: ``[B]`` Q(X1, Y), goal-state scores.
            neg_scores_list: list of ``[B]``: Q(X0, Y), Q(X_cryptic, Y), etc.
                The first entry is used as the lower anchor for the path.
            path_scores_list: list of ``[B]``: Q(X_tau, Y) per path frame,
                expected in order of increasing tau.
            path_taus: list of float tau values for the path frames.
                Currently unused: only the ordering of ``path_scores_list``
                matters.
            margin: margin for the ranking loss and for the gap between the
                last path frame and the positive score.
            path_lambda: weight of the monotonicity term; 0 disables it.

        Returns:
            total_loss: selectivity loss + ``path_lambda`` * monotonicity loss.
            loss_dict: float breakdown with keys ``'loss_selectivity'`` and
                ``'loss_path_monotone'``.
        """
        select_loss = self.compute_selectivity_loss(pos_scores, neg_scores_list, margin)

        loss_monotone = torch.tensor(0.0, device=pos_scores.device)
        if path_scores_list and path_lambda > 0:
            small_margin = 0.05

            # Each frame should score at least small_margin below its successor.
            for earlier, later in zip(path_scores_list, path_scores_list[1:]):
                loss_monotone = loss_monotone + F.relu(
                    earlier - later + small_margin
                ).mean()

            # Last path frame should stay `margin` below the positive score.
            loss_monotone = loss_monotone + F.relu(
                path_scores_list[-1] - pos_scores + margin
            ).mean()

            # First path frame should exceed the first negative score.
            if neg_scores_list:
                loss_monotone = loss_monotone + F.relu(
                    neg_scores_list[0] - path_scores_list[0] + small_margin
                ).mean()

        total = select_loss + path_lambda * loss_monotone
        return total, {
            'loss_selectivity': select_loss.item(),
            'loss_path_monotone': loss_monotone.item(),
        }

    def compute_combined_loss(self, pos_scores, neg_scores_list,
                              dockq_labels, lambda_rank: float = 1.0):
        """Combined Phase 1 + Phase 2 loss.

        The DockQ regression term is computed on ``pos_scores`` only; the
        negatives enter solely through the selectivity term.

        Returns:
            total_loss: ``dockq_loss + lambda_rank * selectivity_loss``.
            loss_dict: float breakdown with keys ``'loss_dockq'`` and
                ``'loss_selectivity'``.
        """
        dockq_loss = self.compute_dockq_loss(pos_scores, dockq_labels)
        select_loss = self.compute_selectivity_loss(pos_scores, neg_scores_list)

        return dockq_loss + lambda_rank * select_loss, {
            'loss_dockq': dockq_loss.item(),
            'loss_selectivity': select_loss.item(),
        }


def build_model(config: dict) -> AlloDesignerScorer:
    """Build the Q_theta scorer from a config dict.

    Missing keys fall back to the defaults listed below.
    """
    # NOTE(review): the fallback node_dim here is 32 whereas the class
    # default is 28 -- confirm which one matches the feature pipeline.
    return AlloDesignerScorer(
        node_dim=config.get('node_dim', 32),
        edge_dim=config.get('edge_dim', 37),
        hidden_dim=config.get('hidden_dim', 128),
        n_layers=config.get('n_layers', 4),
        n_heads=config.get('n_heads', 8),
        dropout=config.get('dropout', 0.1),
        backbone=config.get('backbone', 'transformer'),
        pooling=config.get('pooling', 'meanmax'),
        edge_update=config.get('edge_update', False),
        esm_dim=config.get('esm_dim', 0),
        esm_proj_dim=config.get('esm_proj_dim', 128),
        esm_dropout=config.get('esm_dropout', 0.0),
    )