"""
Q_theta: State-selectivity scorer for Allo-Designer.

Architecture: Dense Edge-Biased Graph Transformer
- Input: padded interface graph (node feats + pairwise edge feats)
- SE(3)-invariant features (all features from distances/angles in backbone frames)
- Output: Q_theta(X, Y) in (0,1) = probability-like compatibility/selectivity score

No torch_geometric dependency: uses dense attention with edge biases.
"""
|
|
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
class RBFLayer(nn.Module):
    """Learnable RBF embedding for edge distances.

    Each scalar distance is expanded into ``n_bins`` Gaussian responses whose
    centers are evenly spaced on ``[d_min, d_max]``. The centers are fixed
    (stored as a buffer); a single shared width ``sigma = exp(log_sigma)`` is
    learned and starts at 1.0.
    """

    def __init__(self, n_bins: int = 16, d_min: float = 0., d_max: float = 20.):
        super().__init__()
        # Buffer (not Parameter): follows .to()/.cuda() but is never trained.
        self.register_buffer('centers', torch.linspace(d_min, d_max, n_bins))
        # Log-space parameterisation keeps sigma strictly positive.
        self.log_sigma = nn.Parameter(torch.zeros(1))

    def forward(self, dist: torch.Tensor) -> torch.Tensor:
        """Return Gaussian bin responses with shape ``dist.shape + (n_bins,)``."""
        sigma = torch.exp(self.log_sigma)
        # Broadcast every distance against all centers along a new last axis.
        offset = dist.unsqueeze(-1) - self.centers
        return torch.exp(-(offset ** 2) / (2 * sigma ** 2))
|
|
|
|
class EdgeBiasedMHA(nn.Module):
    """
    Multi-Head Self-Attention with additive edge biases.

    Implements the core equation:
        A_ij = (Q_i K_j^T / sqrt(d)) + b_ij
    where b_ij is a per-head scalar obtained by a linear projection of the
    edge features of pair (i, j).
    """

    def __init__(self, d_model: int, n_heads: int, d_edge: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.scale = math.sqrt(self.d_head)

        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model)
        # One scalar bias per head for every (i, j) pair.
        self.edge_proj = nn.Linear(d_edge, n_heads)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_feats, mask=None):
        """
        x:          [B, N, d_model]
        edge_feats: [B, N, N, d_edge]
        mask:       [B, N] (True / non-zero = valid residue); optional

        Returns a [B, N, d_model] tensor. Padded residues are excluded as
        attention keys only; rows for padded queries are still computed.
        """
        B, N, D = x.shape
        H = self.n_heads

        # [B, N, 3*D] -> [3, B, H, N, d_head], then split into q, k, v.
        qkv = self.qkv_proj(x).reshape(B, N, 3, H, self.d_head).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # Scaled dot-product logits plus the per-head edge bias: [B, H, N, N].
        attn_logits = (q @ k.transpose(-2, -1)) / self.scale
        attn_logits = attn_logits + self.edge_proj(edge_feats).permute(0, 3, 1, 2)

        if mask is not None:
            # .bool() lets 0/1 float or int masks work the same as bool masks.
            key_padding = ~mask.bool()
            attn_logits = attn_logits.masked_fill(
                key_padding[:, None, None, :], float('-inf')
            )

        attn_weights = F.softmax(attn_logits, dim=-1)
        # A row whose keys are all masked softmaxes to NaN; treat it as
        # "attend to nothing" so it contributes zeros instead of NaNs.
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
        attn_weights = self.dropout(attn_weights)

        # [B, H, N, d_head] -> [B, N, D]
        out = (attn_weights @ v).transpose(1, 2).reshape(B, N, D)
        return self.out_proj(out)
|
|
|
|
class InterfaceTransformerLayer(nn.Module):
    """Single layer of edge-biased transformer with pre-norm.

    Two residual sub-blocks: edge-biased self-attention, then a position-wise
    feed-forward network of width ``d_model * ff_mult``. LayerNorm is applied
    to the input of each sub-block, not to its output.
    """

    def __init__(self, d_model: int, n_heads: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        self.attn = EdgeBiasedMHA(d_model, n_heads, d_edge, dropout)
        d_ff = d_model * ff_mult
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_feats, mask=None):
        """Apply attention and feed-forward sub-blocks; returns a tensor shaped like ``x``."""
        attn_out = self.attn(self.norm1(x), edge_feats, mask)
        x = x + self.drop(attn_out)
        ff_out = self.ff(self.norm2(x))
        return x + self.drop(ff_out)
|
|
|
|
class GATLayer(nn.Module):
    """Multi-head GAT layer with pre-norm. No edge features in attention.

    Attention logits are ``LeakyReLU(a_l . Wh_i + a_r . Wh_j)`` per head over
    the dense N x N graph. ``edge_feats`` is accepted only so the layer shares
    the call signature of the other backbone layers; it is never read.
    """

    def __init__(self, d_model: int, n_heads: int, ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.W = nn.Linear(d_model, d_model, bias=False)
        self.a_l = nn.Parameter(torch.randn(n_heads, self.d_head))
        self.a_r = nn.Parameter(torch.randn(n_heads, self.d_head))
        # Initialise through a [1, H, d_head] view; the in-place init writes
        # into the parameter's own storage.
        nn.init.xavier_uniform_(self.a_l.unsqueeze(0))
        nn.init.xavier_uniform_(self.a_r.unsqueeze(0))
        self.out_proj = nn.Linear(d_model, d_model)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.attn_drop = nn.Dropout(dropout)

        d_ff = d_model * ff_mult
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_feats, mask=None):
        """
        x:          [B, N, d_model]
        edge_feats: unused
        mask:       [B, N] (True / non-zero = valid node); optional

        Returns a [B, N, d_model] tensor.
        """
        B, N, D = x.shape
        H = self.n_heads

        h = self.norm1(x)
        Wh = self.W(h).view(B, N, H, self.d_head)
        src_score = (Wh * self.a_l).sum(-1)  # [B, N, H], contribution of node i
        dst_score = (Wh * self.a_r).sum(-1)  # [B, N, H], contribution of node j
        # Pairwise sum [B, N, N, H] -> [B, H, N, N].
        logits = self.leaky_relu(src_score.unsqueeze(2) + dst_score.unsqueeze(1))
        logits = logits.permute(0, 3, 1, 2)

        if mask is not None:
            # .bool() lets 0/1 float or int masks work the same as bool masks.
            key_padding = ~mask.bool()
            logits = logits.masked_fill(key_padding[:, None, None, :], float('-inf'))

        # Rows with every key masked softmax to NaN; zero them out.
        weights = torch.nan_to_num(F.softmax(logits, dim=-1), nan=0.0)
        weights = self.attn_drop(weights)

        out = torch.einsum('bhnm,bmhd->bnhd', weights, Wh).reshape(B, N, D)
        x = x + self.drop(self.out_proj(out))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x
|
|
|
|
class GCNLayer(nn.Module):
    """GCN layer with edge-weighted message passing and pre-norm.

    Each pair (i, j) gets a scalar weight from a linear map of its edge
    features; weights are softmax-normalised over j and used to average the
    projected node messages.
    """

    def __init__(self, d_model: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        self.msg_proj = nn.Linear(d_model, d_model, bias=False)
        self.edge_weight = nn.Linear(d_edge, 1)

        d_ff = d_model * ff_mult
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_feats, mask=None):
        """
        x:          [B, N, d_model]
        edge_feats: [B, N, N, d_edge]
        mask:       [B, N] (True / non-zero = valid node); optional

        Returns a [B, N, d_model] tensor.
        """
        messages = self.msg_proj(self.norm1(x))

        weights = self.edge_weight(edge_feats).squeeze(-1)  # [B, N, N]
        if mask is not None:
            # Exclude padded nodes as message sources; .bool() lets 0/1 float
            # or int masks work the same as bool masks.
            weights = weights.masked_fill(~mask.bool()[:, None, :], float('-inf'))
        # Rows with every source masked softmax to NaN; zero them out.
        weights = torch.nan_to_num(F.softmax(weights, dim=-1), nan=0.0)

        x = x + self.drop(torch.bmm(weights, messages))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x
|
|
|
|
class CrossChainTransformerLayer(nn.Module):
    """Cross-chain attention: each node attends only to nodes from the other chain.

    Same edge-biased multi-head attention and pre-norm residual layout as the
    self-attention layer, but pairs whose ``chain_mask`` values are equal are
    removed from the attention pattern.
    """

    def __init__(self, d_model: int, n_heads: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.scale = math.sqrt(self.d_head)

        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model)
        self.edge_proj = nn.Linear(d_edge, n_heads)
        self.attn_drop = nn.Dropout(dropout)

        d_ff = d_model * ff_mult
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_feats, mask=None, chain_mask=None):
        """
        x:          [B, N, d_model]
        edge_feats: [B, N, N, d_edge]
        mask:       [B, N] (True / non-zero = valid); optional
        chain_mask: [B, N] float (0=receptor, 1=binder); optional

        Returns a [B, N, d_model] tensor. A node with no valid node on the
        other chain receives zero attention output, so only the out_proj bias
        and the feed-forward block modify it.
        """
        B, N, D = x.shape
        H = self.n_heads

        h = self.norm1(x)
        # [B, N, 3*D] -> [3, B, H, N, d_head], then split into q, k, v.
        qkv = self.qkv_proj(h).reshape(B, N, 3, H, self.d_head).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn_logits = (q @ k.transpose(-2, -1)) / self.scale
        attn_logits = attn_logits + self.edge_proj(edge_feats).permute(0, 3, 1, 2)

        if mask is not None:
            # Drop padded keys; .bool() lets 0/1 float or int masks work too.
            key_padding = ~mask.bool()
            attn_logits = attn_logits.masked_fill(key_padding[:, None, None, :], float('-inf'))

        if chain_mask is not None:
            # [B, N, N]: True where query i and key j carry the same chain id
            # (this includes i == j).
            same_chain = chain_mask.unsqueeze(1) == chain_mask.unsqueeze(2)
            attn_logits = attn_logits.masked_fill(same_chain[:, None, :, :], float('-inf'))

        # Fully masked rows softmax to NaN; zero them out.
        attn_weights = torch.nan_to_num(F.softmax(attn_logits, dim=-1), nan=0.0)
        attn_weights = self.attn_drop(attn_weights)

        out = (attn_weights @ v).transpose(1, 2).reshape(B, N, D)
        x = x + self.drop(self.out_proj(out))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x
|
|
|
|
class EdgeUpdateLayer(nn.Module):
    """Updates edge features using node representations each layer.

    Memory-efficient: nodes are projected to a low dimension
    (``min(32, d_model // 4)``) before being broadcast to all N x N pairs.
    The update is residual: ``e + MLP([proj_i(h_i), proj_j(h_j), LayerNorm(e)])``.
    """

    def __init__(self, d_model: int, d_edge: int, dropout: float = 0.1):
        super().__init__()
        d_proj = min(32, d_model // 4)
        self.proj_i = nn.Linear(d_model, d_proj, bias=False)
        self.proj_j = nn.Linear(d_model, d_proj, bias=False)
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * d_proj + d_edge, d_edge),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_edge, d_edge),
        )
        self.norm = nn.LayerNorm(d_edge)

    def forward(self, h, e, mask=None):
        """
        h:    [B, N, d_model] node states
        e:    [B, N, N, d_edge] edge states
        mask: accepted for interface compatibility; not used

        Returns updated edge states shaped like ``e``.
        """
        n_nodes = h.shape[1]
        # expand() creates broadcast views; the copy happens only in cat().
        row_feats = self.proj_i(h).unsqueeze(2).expand(-1, -1, n_nodes, -1)
        col_feats = self.proj_j(h).unsqueeze(1).expand(-1, n_nodes, -1, -1)
        pair_input = torch.cat([row_feats, col_feats, self.norm(e)], dim=-1)
        return e + self.edge_mlp(pair_input)
|
|
|
|
class InterfaceGNN(nn.Module):
    """
    Q_theta scorer: SE(3)-invariant dense graph transformer for interface scoring.

    Input:
        node_feats: [B, N, node_dim]     per-residue features
        edge_feats: [B, N, N, edge_dim]  pairwise edge features
        mask:       [B, N] bool (True = valid residue, False = padding)

    Output:
        scores: [B] in (0, 1) = Q_theta(X, Y)

    Constructor options:
        backbone:    'transformer', 'gat', 'gcn' or 'crosschain' (alternating
                     self-attention and cross-chain layers, starting with
                     self-attention). Any other value raises ValueError.
        pooling:     'attention' selects learned attention pooling; any other
                     value selects concatenated masked mean + max pooling.
        edge_update: if True, edge states are refreshed from node states after
                     every backbone layer.
        esm_dim:     if > 0, per-residue embeddings of this width are
                     projected to ``esm_proj_dim`` and concatenated to the
                     node features before embedding.
    """

    def __init__(
        self,
        node_dim: int = 28,
        edge_dim: int = 37,
        hidden_dim: int = 128,
        n_layers: int = 4,
        n_heads: int = 8,
        ff_mult: int = 4,
        dropout: float = 0.1,
        backbone: str = 'transformer',
        pooling: str = 'meanmax',
        edge_update: bool = False,
        esm_dim: int = 0,
        esm_proj_dim: int = 128,
        esm_dropout: float = 0.0,
    ):
        super().__init__()
        self.esm_dim = esm_dim
        use_esm = esm_dim > 0
        if use_esm:
            esm_modules = [
                nn.Linear(esm_dim, esm_proj_dim),
                nn.LayerNorm(esm_proj_dim),
                nn.GELU(),
            ]
            if esm_dropout > 0:
                esm_modules.append(nn.Dropout(esm_dropout))
            self.esm_proj = nn.Sequential(*esm_modules)

        actual_node_dim = node_dim + (esm_proj_dim if use_esm else 0)
        self.node_embed = nn.Sequential(
            nn.Linear(actual_node_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
        # Edge states are kept at half the node width.
        d_edge_hidden = hidden_dim // 2
        self.edge_embed = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, d_edge_hidden),
        )

        if backbone == 'transformer':
            blocks = [
                InterfaceTransformerLayer(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout)
                for _ in range(n_layers)
            ]
        elif backbone == 'gat':
            blocks = [
                GATLayer(hidden_dim, n_heads, ff_mult, dropout)
                for _ in range(n_layers)
            ]
        elif backbone == 'gcn':
            blocks = [
                GCNLayer(hidden_dim, d_edge_hidden, ff_mult, dropout)
                for _ in range(n_layers)
            ]
        elif backbone == 'crosschain':
            # Even layers: full self-attention; odd layers: cross-chain only.
            blocks = []
            for idx in range(n_layers):
                layer_cls = InterfaceTransformerLayer if idx % 2 == 0 else CrossChainTransformerLayer
                blocks.append(layer_cls(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout))
        else:
            raise ValueError(f"Unknown backbone: {backbone}")
        self.layers = nn.ModuleList(blocks)

        self.norm_out = nn.LayerNorm(hidden_dim)

        self.edge_update = edge_update
        if edge_update:
            self.edge_update_layers = nn.ModuleList([
                EdgeUpdateLayer(hidden_dim, d_edge_hidden, dropout)
                for _ in range(n_layers)
            ])

        self.pooling = pooling
        if pooling == 'attention':
            self.attn_pool = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.Tanh(),
                nn.Linear(hidden_dim // 2, 1),
            )
            pool_dim = hidden_dim
        else:
            # Mean and max vectors are concatenated.
            pool_dim = 2 * hidden_dim

        self.head = nn.Sequential(
            nn.Linear(pool_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def _pool(self, h, mask):
        """Reduce node states [B, N, H] to one vector per graph, ignoring padding."""
        if self.pooling == 'attention':
            logits = self.attn_pool(h).squeeze(-1)
            logits = logits.masked_fill(~mask, float('-inf'))
            # A graph with no valid node softmaxes to NaN; pool it to zeros.
            weights = torch.nan_to_num(F.softmax(logits, dim=-1), nan=0.0)
            return (h * weights.unsqueeze(-1)).sum(dim=1)

        mask_f = mask.float().unsqueeze(-1)
        h_valid = h * mask_f
        # Epsilon guards against division by zero for an all-padding graph.
        h_mean = h_valid.sum(dim=1) / (mask_f.sum(dim=1) + 1e-8)
        # Push padded positions far below any real activation before the max.
        h_max = (h_valid + (1 - mask_f) * (-1e9)).max(dim=1).values
        return torch.cat([h_mean, h_max], dim=-1)

    def forward(self, node_feats, edge_feats, mask, esm_feats=None):
        """
        node_feats: [B, N, node_dim]
        edge_feats: [B, N, N, edge_dim]
        mask:       [B, N] bool (0/1 numeric masks are converted)
        esm_feats:  [B, N, esm_dim] optional ESM-2 embeddings; required when
                    the model was built with esm_dim > 0, ignored otherwise
        Returns: scores [B] in (0, 1)

        Raises:
            ValueError: if the model was built with esm_dim > 0 and
                ``esm_feats`` is None.
        """
        mask = mask.bool()

        # The last raw node feature is used as the chain id for cross-chain
        # layers; it must be read before ESM features are appended.
        chain_mask = node_feats[:, :, -1]

        if self.esm_dim > 0:
            if esm_feats is None:
                # Without this check the failure is an opaque shape mismatch
                # inside node_embed.
                raise ValueError(
                    f"Model was built with esm_dim={self.esm_dim} but esm_feats is None"
                )
            node_feats = torch.cat([node_feats, self.esm_proj(esm_feats)], dim=-1)

        h = self.node_embed(node_feats)
        e = self.edge_embed(edge_feats)

        for idx, layer in enumerate(self.layers):
            if isinstance(layer, CrossChainTransformerLayer):
                h = layer(h, e, mask, chain_mask=chain_mask)
            else:
                h = layer(h, e, mask)
            if self.edge_update:
                e = self.edge_update_layers[idx](h, e, mask)

        h = self.norm_out(h)
        h_pool = self._pool(h, mask)

        logits = self.head(h_pool).squeeze(-1)
        return torch.sigmoid(logits)
|
|
|
|
class AlloDesignerScorer(nn.Module):
    """
    Full Q_theta model wrapper with loss computation.

    Implements the two-stage training objective:
        Phase 1: DockQ regression (MSE loss)
        Phase 2: Selectivity margin ranking (contrastive loss)

    The selectivity margin from the paper (Eq. 3):
        S_theta(Y; X+, N) = logit(Q(X+, Y)) - log sum_X- exp(logit(Q(X-, Y)))
    """

    def __init__(self, node_dim=28, edge_dim=37, hidden_dim=128,
                 n_layers=4, n_heads=8, dropout=0.1, backbone='transformer',
                 pooling='meanmax', edge_update=False, esm_dim=0,
                 esm_proj_dim=128, esm_dropout=0.0):
        super().__init__()
        self.gnn = InterfaceGNN(node_dim, edge_dim, hidden_dim, n_layers, n_heads,
                                dropout=dropout, backbone=backbone,
                                pooling=pooling, edge_update=edge_update,
                                esm_dim=esm_dim, esm_proj_dim=esm_proj_dim,
                                esm_dropout=esm_dropout)

    def forward(self, node_feats, edge_feats, mask, esm_feats=None):
        """Score a batch of interface graphs; delegates to the wrapped GNN."""
        return self.gnn(node_feats, edge_feats, mask, esm_feats=esm_feats)

    @staticmethod
    def _logit(prob, eps: float = 1e-6):
        """Numerically clamped log-odds ``log(p / (1 - p))`` of a probability tensor."""
        return torch.log(prob.clamp(eps, 1 - eps) / (1 - prob).clamp(eps))

    def compute_dockq_loss(self, scores, dockq_labels):
        """Phase 1: MSE regression loss against DockQ labels.

        Labels with the same number of elements as ``scores`` but a different
        shape (e.g. [B, 1] vs [B]) are reshaped to match, so the loss is not
        silently broadcast to a [B, B] comparison.
        """
        target = dockq_labels.float()
        if target.shape != scores.shape and target.numel() == scores.numel():
            target = target.reshape(scores.shape)
        return F.mse_loss(scores, target)

    def compute_selectivity_loss(self, pos_scores, neg_scores_list, margin: float = 0.2):
        """
        Phase 2: Selectivity margin loss.

        For each binder Y:
            pos_score = Q(X+, Y)
            neg_scores = [Q(X-, Y) for X- in N]

        Returns ``-mean(S_theta) + L_margin`` where
            S_theta  = logit(pos_score) - log sum exp(logit(neg_scores))
            L_margin = mean(max(0, margin - (pos_score - neg_score)))

        With no negatives there is nothing to rank against, so a zero loss
        (still attached to ``pos_scores``' graph) is returned.
        """
        # len() rather than truthiness so a stacked tensor of negatives works.
        if len(neg_scores_list) == 0:
            return pos_scores.sum() * 0.0

        neg_scores = torch.stack(list(neg_scores_list), dim=-1)  # [B, K]

        selectivity = self._logit(pos_scores) - torch.logsumexp(self._logit(neg_scores), dim=-1)
        selectivity_loss = -selectivity.mean()

        # Hinge on the raw score gap, averaged over all (binder, negative) pairs.
        score_gap = pos_scores.unsqueeze(-1) - neg_scores
        margin_loss = F.relu(margin - score_gap).mean()

        return selectivity_loss + margin_loss

    def compute_path_selectivity_loss(self, pos_scores, neg_scores_list,
                                      path_scores_list, path_taus,
                                      margin=0.2, path_lambda=0.5):
        """
        Extended selectivity loss with path monotonicity regularization.

        Args:
            pos_scores: [B] Q(X1, Y) -- goal state scores
            neg_scores_list: list of [B] -- Q(X0, Y), Q(X_cryptic, Y), etc.
            path_scores_list: list of [B] -- Q(X_tau, Y) for each path frame,
                already ordered by increasing tau
            path_taus: list of float -- tau values for each path frame.
                Not read by this method; the order of ``path_scores_list``
                alone defines the path.
            margin: margin for ranking loss
            path_lambda: weight for path monotonicity loss

        Returns:
            total_loss: selectivity loss + path_lambda * monotonicity loss
            loss_dict: breakdown of loss components (Python floats)
        """
        select_loss = self.compute_selectivity_loss(pos_scores, neg_scores_list, margin)

        loss_monotone = pos_scores.new_zeros(())
        if path_scores_list and path_lambda > 0:
            small_margin = 0.05

            # Scores should increase along the path: each frame must beat its
            # predecessor by small_margin.
            for earlier, later in zip(path_scores_list, path_scores_list[1:]):
                loss_monotone = loss_monotone + F.relu(earlier - later + small_margin).mean()

            # The goal state must beat the last path frame by the full margin.
            loss_monotone = loss_monotone + F.relu(
                path_scores_list[-1] - pos_scores + margin
            ).mean()

            # The first path frame must beat the first negative.
            if len(neg_scores_list) > 0:
                loss_monotone = loss_monotone + F.relu(
                    neg_scores_list[0] - path_scores_list[0] + small_margin
                ).mean()

        total = select_loss + path_lambda * loss_monotone
        return total, {
            'loss_selectivity': select_loss.item(),
            'loss_path_monotone': loss_monotone.item(),
        }

    def compute_combined_loss(self, pos_scores, neg_scores_list, dockq_labels,
                              lambda_rank: float = 1.0, margin: float = 0.2):
        """Combined Phase 1 + Phase 2 loss.

        Returns ``(dockq_loss + lambda_rank * selectivity_loss, loss_dict)``.
        ``margin`` is forwarded to the selectivity loss.
        """
        dockq_loss = self.compute_dockq_loss(pos_scores, dockq_labels)
        select_loss = self.compute_selectivity_loss(pos_scores, neg_scores_list, margin)

        return dockq_loss + lambda_rank * select_loss, {
            'loss_dockq': dockq_loss.item(),
            'loss_selectivity': select_loss.item(),
        }
|
|
|
|
def build_model(config: dict) -> AlloDesignerScorer:
    """Build the Q_theta scorer from a config dict.

    Only the keys listed below are read; each missing key falls back to the
    default shown. Other keys in ``config`` are ignored.
    """
    # NOTE(review): the fallback node_dim here is 32, while the
    # AlloDesignerScorer constructor defaults to 28 -- confirm which one is
    # intended. Kept at 32 so configs that omit node_dim build the same model.
    defaults = {
        'node_dim': 32,
        'edge_dim': 37,
        'hidden_dim': 128,
        'n_layers': 4,
        'n_heads': 8,
        'dropout': 0.1,
        'backbone': 'transformer',
        'pooling': 'meanmax',
        'edge_update': False,
        'esm_dim': 0,
        'esm_proj_dim': 128,
        'esm_dropout': 0.0,
    }
    kwargs = {key: config.get(key, fallback) for key, fallback in defaults.items()}
    return AlloDesignerScorer(**kwargs)
|
|