AlloGen / code /models /scorer.py
chq1155's picture
AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
ad9572d
"""
Q_theta: State-selectivity scorer for Allo-Designer.
Architecture: Dense Edge-Biased Graph Transformer
- Input: padded interface graph (node feats + pairwise edge feats)
- SE(3)-invariant features (all features from distances/angles in backbone frames)
- Output: Q_theta(X, Y) in (0,1) = probability-like compatibility/selectivity score
No torch_geometric dependency: uses dense attention with edge biases.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class RBFLayer(nn.Module):
    """Gaussian radial-basis expansion of scalar distances.

    ``n_bins`` centers are spaced evenly on ``[d_min, d_max]`` and kept as a
    non-trainable buffer. All Gaussians share a single trainable width
    ``sigma = exp(log_sigma)``, which starts at ``sigma = 1``.
    """

    def __init__(self, n_bins: int = 16, d_min: float = 0., d_max: float = 20.):
        super().__init__()
        self.register_buffer('centers', torch.linspace(d_min, d_max, n_bins))
        # Parameterised in log space so the width is positive for any value.
        self.log_sigma = nn.Parameter(torch.zeros(1))

    def forward(self, dist):
        """Expand distances of shape ``[...]`` into responses ``[..., n_bins]``."""
        width = torch.exp(self.log_sigma)
        # Broadcast every distance against every center.
        offset = dist.unsqueeze(-1) - self.centers
        squared = offset ** 2
        return torch.exp(-squared / (2 * width ** 2))
class EdgeBiasedMHA(nn.Module):
    """
    Dense multi-head self-attention with an additive, per-head edge bias.

    For each head the attention logit between nodes i and j is
        A_ij = (Q_i . K_j) / sqrt(d_head) + b_ij
    where b_ij is a linear projection of the pairwise edge feature.
    """

    def __init__(self, d_model: int, n_heads: int, d_edge: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.scale = math.sqrt(self.d_head)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model)
        # One scalar bias per head for every (i, j) pair.
        self.edge_proj = nn.Linear(d_edge, n_heads)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_feats, mask=None):
        """
        x: [B, N, d_model]
        edge_feats: [B, N, N, d_edge]
        mask: [B, N] bool (True = valid residue); padded keys are never attended to
        Returns: [B, N, d_model]
        """
        n_batch, n_nodes, width = x.shape
        heads = self.n_heads

        # Fused projection, then split into three [B, H, N, d_head] tensors.
        fused = self.qkv_proj(x).reshape(n_batch, n_nodes, 3, heads, self.d_head)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)

        logits = (queries @ keys.transpose(-2, -1)) / self.scale  # [B, H, N, N]
        # [B, N, N, H] -> [B, H, N, N] so the bias lines up with the logits.
        logits = logits + self.edge_proj(edge_feats).permute(0, 3, 1, 2)

        if mask is not None:
            key_is_pad = ~mask  # [B, N], True where the key is padding
            logits = logits.masked_fill(key_is_pad[:, None, None, :], float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))
        # A row whose keys are all masked softmaxes to NaN; treat it as "attend to nothing".
        weights = torch.nan_to_num(weights, nan=0.0)

        mixed = weights @ values  # [B, H, N, d_head]
        merged = mixed.transpose(1, 2).reshape(n_batch, n_nodes, width)
        return self.out_proj(merged)
class InterfaceTransformerLayer(nn.Module):
    """Pre-norm transformer block: edge-biased self-attention, then a GELU feed-forward.

    Both sub-blocks are wrapped as ``x + dropout(sublayer(layer_norm(x)))``.
    """

    def __init__(self, d_model: int, n_heads: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        d_ff = d_model * ff_mult
        self.attn = EdgeBiasedMHA(d_model, n_heads, d_edge, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_feats, mask=None):
        """Apply the block to ``x``; ``edge_feats`` and ``mask`` are forwarded to the attention."""
        attended = self.attn(self.norm1(x), edge_feats, mask)
        x = x + self.drop(attended)
        transformed = self.ff(self.norm2(x))
        return x + self.drop(transformed)
class GATLayer(nn.Module):
    """Pre-norm multi-head GAT block over a dense (fully connected) graph.

    Attention logits are ``LeakyReLU(a_l . Wh_i + a_r . Wh_j)`` per head.
    Edge features are accepted for interface parity with the other layers
    but are not used.
    """

    def __init__(self, d_model: int, n_heads: int, ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W = nn.Linear(d_model, d_model, bias=False)
        self.a_l = nn.Parameter(torch.randn(n_heads, self.d_head))
        self.a_r = nn.Parameter(torch.randn(n_heads, self.d_head))
        # xavier_uniform_ writes in place through the unsqueezed view of each parameter.
        nn.init.xavier_uniform_(self.a_l.unsqueeze(0))
        nn.init.xavier_uniform_(self.a_r.unsqueeze(0))
        self.out_proj = nn.Linear(d_model, d_model)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.attn_drop = nn.Dropout(dropout)
        d_ff = d_model * ff_mult
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_feats, mask=None):
        """
        x: [B, N, d_model]
        edge_feats: ignored
        mask: [B, N] bool (True = valid); padded nodes are never attended to
        Returns: [B, N, d_model]
        """
        n_batch, n_nodes, width = x.shape
        heads = self.n_heads

        projected = self.W(self.norm1(x)).view(n_batch, n_nodes, heads, self.d_head)
        query_term = (projected * self.a_l).sum(-1)  # [B, N, H]
        key_term = (projected * self.a_r).sum(-1)    # [B, N, H]

        # pair[b, i, j, h] = query_term[b, i, h] + key_term[b, j, h]
        pair = query_term.unsqueeze(2) + key_term.unsqueeze(1)
        logits = self.leaky_relu(pair).permute(0, 3, 1, 2)  # [B, H, N, N]
        if mask is not None:
            logits = logits.masked_fill(~mask[:, None, None, :], float('-inf'))

        weights = self.attn_drop(F.softmax(logits, dim=-1))
        # Rows with every key masked softmax to NaN; zero them out.
        weights = torch.nan_to_num(weights, nan=0.0)

        gathered = torch.einsum('bhnm,bmhd->bnhd', weights, projected)
        gathered = gathered.reshape(n_batch, n_nodes, width)

        x = x + self.drop(self.out_proj(gathered))
        return x + self.drop(self.ff(self.norm2(x)))
class GCNLayer(nn.Module):
    """Pre-norm message-passing block with scalar, edge-derived weights.

    Each edge feature is projected to one logit; logits are softmax-normalised
    over the neighbours of every node and used to average projected node
    messages. A GELU feed-forward sub-block follows.
    """

    def __init__(self, d_model: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        self.msg_proj = nn.Linear(d_model, d_model, bias=False)
        self.edge_weight = nn.Linear(d_edge, 1)
        d_ff = d_model * ff_mult
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_feats, mask=None):
        """
        x: [B, N, d_model]
        edge_feats: [B, N, N, d_edge]
        mask: [B, N] bool (True = valid); padded neighbours get zero weight
        Returns: [B, N, d_model]
        """
        messages = self.msg_proj(self.norm1(x))               # [B, N, D]
        edge_logits = self.edge_weight(edge_feats).squeeze(-1)  # [B, N, N]
        if mask is not None:
            neighbour_is_pad = ~mask[:, None, :]
            edge_logits = edge_logits.masked_fill(neighbour_is_pad, float('-inf'))

        weights = F.softmax(edge_logits, dim=-1)
        # Nodes whose neighbours are all masked get NaN rows; make them aggregate nothing.
        weights = torch.nan_to_num(weights, nan=0.0)

        aggregated = torch.bmm(weights, messages)  # [B, N, D]
        x = x + self.drop(aggregated)
        return x + self.drop(self.ff(self.norm2(x)))
class CrossChainTransformerLayer(nn.Module):
    """Pre-norm edge-biased attention block restricted to inter-chain pairs.

    When a chain indicator is supplied, a node may only attend to nodes whose
    indicator differs from its own; same-chain pairs (including the node
    itself) are masked out.
    """

    def __init__(self, d_model: int, n_heads: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.scale = math.sqrt(self.d_head)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model)
        self.edge_proj = nn.Linear(d_edge, n_heads)
        self.attn_drop = nn.Dropout(dropout)
        d_ff = d_model * ff_mult
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_feats, mask=None, chain_mask=None):
        """
        x: [B, N, d_model]
        edge_feats: [B, N, N, d_edge]
        mask: [B, N] bool (True = valid)
        chain_mask: [B, N] float (0=receptor, 1=binder)
        Returns: [B, N, d_model]
        """
        n_batch, n_nodes, width = x.shape
        heads = self.n_heads

        fused = self.qkv_proj(self.norm1(x)).reshape(n_batch, n_nodes, 3, heads, self.d_head)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)  # each [B, H, N, d_head]

        logits = (queries @ keys.transpose(-2, -1)) / self.scale  # [B, H, N, N]
        logits = logits + self.edge_proj(edge_feats).permute(0, 3, 1, 2)

        # Collect every forbidden (query, key) pair, then mask once.
        blocked = None
        if mask is not None:
            blocked = ~mask[:, None, None, :]  # padded keys, [B, 1, 1, N]
        if chain_mask is not None:
            same_chain = chain_mask.unsqueeze(1) == chain_mask.unsqueeze(2)  # [B, N, N]
            same_chain = same_chain[:, None, :, :]
            blocked = same_chain if blocked is None else (blocked | same_chain)
        if blocked is not None:
            logits = logits.masked_fill(blocked, float('-inf'))

        weights = self.attn_drop(F.softmax(logits, dim=-1))
        # A node with no admissible partner yields a NaN row; it then receives no message.
        weights = torch.nan_to_num(weights, nan=0.0)

        merged = (weights @ values).transpose(1, 2).reshape(n_batch, n_nodes, width)
        x = x + self.drop(self.out_proj(merged))
        return x + self.drop(self.ff(self.norm2(x)))
class EdgeUpdateLayer(nn.Module):
    """Residual refresh of edge features from the current node states.

    Each node is first projected down to a small width (at most 32) to keep
    the [B, N, N, *] intermediate cheap. For edge (i, j) the projection of
    node i, the projection of node j and the layer-normalised edge feature
    are concatenated and passed through an MLP; the result is added to the
    incoming edge feature.
    """

    def __init__(self, d_model: int, d_edge: int, dropout: float = 0.1):
        super().__init__()
        d_proj = min(32, d_model // 4)  # narrow projection to limit memory
        self.proj_i = nn.Linear(d_model, d_proj, bias=False)
        self.proj_j = nn.Linear(d_model, d_proj, bias=False)
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * d_proj + d_edge, d_edge),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_edge, d_edge),
        )
        self.norm = nn.LayerNorm(d_edge)

    def forward(self, h, e, mask=None):
        """
        h: [B, N, d_model] node states
        e: [B, N, N, d_edge] edge features
        mask: accepted for interface parity; not used
        Returns: updated edge features, same shape as ``e``
        """
        _, n_nodes, _ = h.shape
        row_part = self.proj_i(h).unsqueeze(2).expand(-1, -1, n_nodes, -1)  # node i along rows
        col_part = self.proj_j(h).unsqueeze(1).expand(-1, n_nodes, -1, -1)  # node j along columns
        features = torch.cat([row_part, col_part, self.norm(e)], dim=-1)
        return e + self.edge_mlp(features)
class InterfaceGNN(nn.Module):
    """
    Q_theta scorer: dense graph network mapping a padded interface graph to one score.

    Node and edge features are embedded, passed through ``n_layers`` blocks of
    the selected backbone (optionally refreshing the edge features after every
    block), pooled over valid nodes and scored by an MLP head + sigmoid.

    Input:
        node_feats: [B, N, node_dim] per-residue features; the LAST channel is
            read as the chain indicator (0 = receptor, 1 = binder)
        edge_feats: [B, N, N, edge_dim] pairwise edge features
        mask: [B, N] (True / non-zero = valid residue, False / 0 = padding)
        esm_feats: [B, N, esm_dim] embeddings; required when esm_dim > 0
    Output:
        scores: [B] in (0, 1) = Q_theta(X, Y)
    """

    def __init__(
        self,
        node_dim: int = 28,
        edge_dim: int = 37,
        hidden_dim: int = 128,
        n_layers: int = 4,
        n_heads: int = 8,
        ff_mult: int = 4,
        dropout: float = 0.1,
        backbone: str = 'transformer',
        pooling: str = 'meanmax',  # 'meanmax' or 'attention'
        edge_update: bool = False,
        esm_dim: int = 0,  # 0 = no ESM; >0 = ESM embedding dim to project
        esm_proj_dim: int = 128,  # projection dim for ESM features
        esm_dropout: float = 0.0,  # dropout on ESM projection
    ):
        """
        Args:
            node_dim: width of the raw per-node features.
            edge_dim: width of the raw per-edge features.
            hidden_dim: node hidden width; edges use ``hidden_dim // 2``.
            n_layers: number of backbone blocks.
            n_heads: attention heads (ignored by the 'gcn' backbone).
            ff_mult: feed-forward expansion factor inside each block.
            dropout: dropout probability used in the blocks and the head.
            backbone: 'transformer', 'gat', 'gcn' or 'crosschain'
                ('crosschain' alternates self-attention and cross-chain blocks).
            pooling: 'attention' for learned attention pooling; any other
                value selects mean+max pooling.
            edge_update: if True, edge features are updated after every block.
            esm_dim: input width of the optional ESM embeddings (0 disables them).
            esm_proj_dim: width the ESM embeddings are projected to.
            esm_dropout: dropout applied after the ESM projection.

        Raises:
            ValueError: if ``backbone`` is not one of the supported names.
        """
        super().__init__()
        actual_node_dim = node_dim + (esm_proj_dim if esm_dim > 0 else 0)
        self.esm_dim = esm_dim
        if esm_dim > 0:
            esm_modules = [
                nn.Linear(esm_dim, esm_proj_dim),
                nn.LayerNorm(esm_proj_dim),
                nn.GELU(),
            ]
            if esm_dropout > 0:
                esm_modules.append(nn.Dropout(esm_dropout))
            self.esm_proj = nn.Sequential(*esm_modules)
        self.node_embed = nn.Sequential(
            nn.Linear(actual_node_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
        self.edge_embed = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
        )
        d_edge_hidden = hidden_dim // 2
        if backbone == 'transformer':
            self.layers = nn.ModuleList([
                InterfaceTransformerLayer(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout)
                for _ in range(n_layers)
            ])
        elif backbone == 'gat':
            self.layers = nn.ModuleList([
                GATLayer(hidden_dim, n_heads, ff_mult, dropout)
                for _ in range(n_layers)
            ])
        elif backbone == 'gcn':
            self.layers = nn.ModuleList([
                GCNLayer(hidden_dim, d_edge_hidden, ff_mult, dropout)
                for _ in range(n_layers)
            ])
        elif backbone == 'crosschain':
            # Interleave self-attention (even index) and cross-chain attention (odd index)
            blocks = []
            for i in range(n_layers):
                if i % 2 == 0:
                    blocks.append(InterfaceTransformerLayer(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout))
                else:
                    blocks.append(CrossChainTransformerLayer(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout))
            self.layers = nn.ModuleList(blocks)
        else:
            raise ValueError(f"Unknown backbone: {backbone}")
        self.norm_out = nn.LayerNorm(hidden_dim)
        # Edge update layers (optional), one per backbone block
        self.edge_update = edge_update
        if edge_update:
            self.edge_update_layers = nn.ModuleList([
                EdgeUpdateLayer(hidden_dim, d_edge_hidden, dropout)
                for _ in range(n_layers)
            ])
        # Pooling
        self.pooling = pooling
        if pooling == 'attention':
            self.attn_pool = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.Tanh(),
                nn.Linear(hidden_dim // 2, 1),
            )
            pool_dim = hidden_dim
        else:
            # Mean and max vectors are concatenated.
            pool_dim = 2 * hidden_dim
        # Scoring head
        self.head = nn.Sequential(
            nn.Linear(pool_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, node_feats, edge_feats, mask, esm_feats=None):
        """
        node_feats: [B, N, node_dim]
        edge_feats: [B, N, N, edge_dim]
        mask: [B, N] bool (0/1 integer or float masks are converted)
        esm_feats: [B, N, esm_dim] optional ESM-2 embeddings; ignored when the
            model was built with esm_dim == 0
        Returns: scores [B] in (0, 1)

        Raises:
            ValueError: if the model was built with esm_dim > 0 and
                ``esm_feats`` is not supplied.
        """
        # The layers and the pooling below negate the mask with ``~``, which
        # is only a logical NOT on bool tensors.
        mask = mask.bool()
        # Extract chain mask for cross-chain attention (last dim = chain indicator).
        # Must happen before ESM features are appended to the node features.
        chain_mask = node_feats[:, :, -1]  # [B, N] float: 0=receptor, 1=binder
        if self.esm_dim > 0:
            if esm_feats is None:
                # node_embed was sized for node_dim + esm_proj_dim; without the
                # ESM block it would fail with an opaque shape error.
                raise ValueError(
                    f"esm_feats is required: this model was built with esm_dim={self.esm_dim}"
                )
            esm_proj = self.esm_proj(esm_feats)  # [B, N, esm_proj_dim]
            node_feats = torch.cat([node_feats, esm_proj], dim=-1)
        # Embed nodes and edges
        h = self.node_embed(node_feats)  # [B, N, hidden_dim]
        e = self.edge_embed(edge_feats)  # [B, N, N, hidden_dim//2]
        # Backbone blocks (with optional edge updates)
        for i, layer in enumerate(self.layers):
            if isinstance(layer, CrossChainTransformerLayer):
                h = layer(h, e, mask, chain_mask=chain_mask)
            else:
                h = layer(h, e, mask)
            if self.edge_update:
                e = self.edge_update_layers[i](h, e, mask)
        h = self.norm_out(h)  # [B, N, hidden_dim]
        # Pooling over valid nodes
        if self.pooling == 'attention':
            # Learned attention pooling
            attn_logits = self.attn_pool(h).squeeze(-1)  # [B, N]
            attn_logits = attn_logits.masked_fill(~mask, float('-inf'))
            attn_weights = F.softmax(attn_logits, dim=-1).unsqueeze(-1)  # [B, N, 1]
            # All-padding samples softmax to NaN; pool them to the zero vector.
            attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
            h_pool = (h * attn_weights).sum(dim=1)  # [B, hidden_dim]
        else:
            # Mean + max pooling
            valid = mask.unsqueeze(-1)  # [B, N, 1] bool
            mask_f = valid.to(h.dtype)
            h_mean = (h * mask_f).sum(dim=1) / (mask_f.sum(dim=1) + 1e-8)
            h_max = h.masked_fill(~valid, float('-inf')).max(dim=1).values
            # A sample without any valid node would otherwise feed a huge
            # negative sentinel into the head; use 0, matching h_mean and the
            # attention-pooling branch.
            has_valid = mask.any(dim=1, keepdim=True)  # [B, 1]
            h_max = torch.where(has_valid, h_max, torch.zeros_like(h_max))
            h_pool = torch.cat([h_mean, h_max], dim=-1)  # [B, 2*hidden_dim]
        # Score
        logits = self.head(h_pool).squeeze(-1)  # [B]
        return torch.sigmoid(logits)  # [B] in (0, 1)
class AlloDesignerScorer(nn.Module):
    """
    Full Q_theta model wrapper with loss computation.

    Wraps an ``InterfaceGNN`` and provides the training objectives:
        Phase 1: DockQ regression (MSE loss)
        Phase 2: Selectivity margin ranking (contrastive loss)
    The selectivity margin from the paper (Eq. 3):
        S_theta(Y; X+, N) = logit(Q(X+, Y)) - log sum_X- exp(logit(Q(X-, Y)))
    """

    def __init__(self, node_dim=28, edge_dim=37, hidden_dim=128,
                 n_layers=4, n_heads=8, dropout=0.1, backbone='transformer',
                 pooling='meanmax', edge_update=False, esm_dim=0,
                 esm_proj_dim=128, esm_dropout=0.0):
        """Build the underlying ``InterfaceGNN``; all arguments are forwarded to it."""
        super().__init__()
        self.gnn = InterfaceGNN(node_dim, edge_dim, hidden_dim, n_layers, n_heads,
                                dropout=dropout, backbone=backbone,
                                pooling=pooling, edge_update=edge_update,
                                esm_dim=esm_dim, esm_proj_dim=esm_proj_dim,
                                esm_dropout=esm_dropout)

    def forward(self, node_feats, edge_feats, mask, esm_feats=None):
        """Return the scores of the wrapped network for one batch."""
        return self.gnn(node_feats, edge_feats, mask, esm_feats=esm_feats)

    @staticmethod
    def _logit(scores, eps: float = 1e-6):
        """Inverse sigmoid log(p / (1 - p)), clamped so scores at 0 or 1 stay finite."""
        return torch.log(scores.clamp(eps, 1 - eps) / (1 - scores).clamp(eps))

    def compute_dockq_loss(self, scores, dockq_labels):
        """Phase 1: MSE regression loss against DockQ labels.

        ``dockq_labels`` is cast to float; it should have the same shape as
        ``scores``.
        """
        return F.mse_loss(scores, dockq_labels.float())

    def compute_selectivity_loss(self, pos_scores, neg_scores_list, margin: float = 0.2):
        """
        Phase 2: Selectivity margin loss.

        For each binder Y:
            pos_score = Q(X+, Y)
            neg_scores = [Q(X-, Y) for X- in N]
        The returned scalar is the sum of two terms:
            -mean(S_theta), where
                S_theta = logit(pos_score) - log sum exp(logit(neg_scores))
            mean(max(0, margin - (pos_score - neg_score))), averaged over
                all samples and all negatives

        Args:
            pos_scores: [B] scores for the target state.
            neg_scores_list: non-empty list of [B] score tensors, one per
                negative state.
            margin: required gap between positive and negative scores.

        Raises:
            ValueError: if ``neg_scores_list`` is empty (there is nothing to
                contrast the positive against).
        """
        if len(neg_scores_list) == 0:
            raise ValueError(
                "neg_scores_list must contain at least one tensor of negative-state scores"
            )
        neg_scores = torch.stack(list(neg_scores_list), dim=-1)  # [B, n_neg]

        # Selectivity margin in logit space
        pos_logit = self._logit(pos_scores)                      # [B]
        log_denom = torch.logsumexp(self._logit(neg_scores), dim=-1)  # [B]
        selectivity_loss = -(pos_logit - log_denom).mean()

        # Hinge on the raw score gap, averaged over samples and negatives
        gap = pos_scores.unsqueeze(-1) - neg_scores              # [B, n_neg]
        margin_loss = F.relu(margin - gap).mean()
        return selectivity_loss + margin_loss

    def compute_path_selectivity_loss(self, pos_scores, neg_scores_list,
                                      path_scores_list, path_taus,
                                      margin=0.2, path_lambda=0.5):
        """
        Extended selectivity loss with path monotonicity regularization.

        Args:
            pos_scores: [B] Q(X1, Y) -- goal state scores
            neg_scores_list: list of [B] -- Q(X0, Y), Q(X_cryptic, Y), etc.
            path_scores_list: list of [B] -- Q(X_tau, Y) for each path frame,
                ordered by increasing tau
            path_taus: list of float -- tau values for each path frame. Not
                read by this method; only the order of ``path_scores_list``
                matters.
            margin: margin for ranking loss
            path_lambda: weight for path monotonicity loss
        Returns:
            total_loss: selectivity loss + path_lambda * monotonicity loss
            loss_dict: breakdown of loss components (Python floats)
        Raises:
            ValueError: if ``neg_scores_list`` is empty.
        """
        # Standard selectivity loss (unchanged)
        select_loss = self.compute_selectivity_loss(pos_scores, neg_scores_list, margin)
        # Path monotonicity loss: ensure Q increases with tau
        loss_monotone = torch.tensor(0.0, device=pos_scores.device)
        if path_scores_list and path_lambda > 0:
            small_margin = 0.05
            # Each frame should score below the next one by at least small_margin
            for earlier, later in zip(path_scores_list, path_scores_list[1:]):
                loss_monotone = loss_monotone + F.relu(earlier - later + small_margin).mean()
            # Last path frame should score below the positive (holo) state by `margin`
            loss_monotone = loss_monotone + F.relu(
                path_scores_list[-1] - pos_scores + margin
            ).mean()
            # First path frame should score above the first negative (apo) state
            if neg_scores_list:
                loss_monotone = loss_monotone + F.relu(
                    neg_scores_list[0] - path_scores_list[0] + small_margin
                ).mean()
        total = select_loss + path_lambda * loss_monotone
        return total, {
            'loss_selectivity': select_loss.item(),
            'loss_path_monotone': loss_monotone.item(),
        }

    def compute_combined_loss(self, pos_scores, neg_scores_list, dockq_labels,
                              lambda_rank: float = 1.0):
        """Combined Phase 1 + Phase 2 loss.

        The DockQ regression term is computed on ``pos_scores`` only; the
        negative scores enter through the selectivity term (default margin).

        Returns:
            total_loss: dockq loss + lambda_rank * selectivity loss
            loss_dict: breakdown of loss components (Python floats)
        """
        dockq_loss = self.compute_dockq_loss(pos_scores, dockq_labels)
        select_loss = self.compute_selectivity_loss(pos_scores, neg_scores_list)
        return dockq_loss + lambda_rank * select_loss, {
            'loss_dockq': dockq_loss.item(),
            'loss_selectivity': select_loss.item(),
        }
def build_model(config: dict) -> AlloDesignerScorer:
    """Instantiate the Q_theta scorer from a config mapping.

    Every key is optional; a missing key falls back to the value listed below.
    Keys not listed here are ignored.
    """
    # NOTE(review): the 'node_dim' fallback here is 32, whereas the
    # AlloDesignerScorer constructor defaults to 28 -- confirm which one
    # matches the feature pipeline / released checkpoints.
    fallbacks = {
        'node_dim': 32,
        'edge_dim': 37,
        'hidden_dim': 128,
        'n_layers': 4,
        'n_heads': 8,
        'dropout': 0.1,
        'backbone': 'transformer',
        'pooling': 'meanmax',
        'edge_update': False,
        'esm_dim': 0,
        'esm_proj_dim': 128,
        'esm_dropout': 0.0,
    }
    settings = {key: config.get(key, default) for key, default in fallbacks.items()}
    return AlloDesignerScorer(**settings)