AlloGen / code /models /features.py
chq1155's picture
AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
ad9572d
"""
SE(3)-invariant feature extraction for interface graphs.
Node and edge features used by the Q_theta scorer.
"""
import os
import sys
import numpy as np
# Ensure utils is importable (for both direct and package imports)
_CODE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _CODE_DIR not in sys.path:
sys.path.insert(0, _CODE_DIR)
from utils.pdb_utils import (
rbf_encode, compute_backbone_frames, compute_torsion_angles,
get_aa_indices, compute_chi_angles, get_cb_positions, NUM_AA
)
# Feature dimensions
# Node vector layout: amino-acid one-hot (NUM_AA) + backbone torsion sin/cos (6)
# + chi1 sin/cos (2) + chi2 sin/cos (2) + chain indicator (1).
# Evaluates to 32 when NUM_AA is 21.
NODE_DIM = NUM_AA + 6 + (2 + 2) + 1
# Edge vector layout: RBF distance (16) + direction (3) + flattened relative
# rotation (9) + sequence-separation one-hot (8) + same-chain flag (1) = 37.
EDGE_DIM = sum((16, 3, 9, 8, 1))
MAX_SEQ_SEP = 32  # separations are clipped to [-MAX_SEQ_SEP, MAX_SEQ_SEP] before binning


def seq_sep_encode(sep, n_bins=8, max_sep=MAX_SEQ_SEP):
    """
    One-hot encode signed sequence separation into equal-width bins.

    The range [-max_sep, max_sep] is divided into ``n_bins`` equal-width bins;
    values outside the range are clipped into the first / last bin.

    Args:
        sep: scalar or array-like of signed sequence separations.
        n_bins: number of bins (length of the one-hot axis).
        max_sep: absolute separation at which values saturate.
    Returns:
        float32 array of shape ``np.shape(sep) + (n_bins,)``; a scalar ``sep``
        therefore yields a vector of shape [n_bins]. Exactly one entry along
        the last axis is 1.0 for every input element.
    """
    sep_arr = np.asarray(sep)
    edges = np.linspace(-max_sep, max_sep, n_bins + 1)
    clipped = np.clip(sep_arr, -max_sep, max_sep)
    # np.digitize is 1-based for in-range values; shift to 0-based and fold the
    # right edge (sep == max_sep) back into the last bin.
    bin_idx = np.asarray(np.clip(np.digitize(clipped, edges) - 1, 0, n_bins - 1))
    # Comparing against the bin ids builds the one-hot for every element at once,
    # so array inputs get one row per element instead of a merged multi-hot.
    return (bin_idx[..., None] == np.arange(n_bins)).astype(np.float32)
def extract_node_features(residues, coords, mask, torsion_angles, chi_angles, chain_id):
    """
    Assemble the per-residue node feature matrix.

    Args:
        residues: list of Bio.PDB residues
        coords: [N, 4, 3] backbone coords (accepted for interface symmetry; not read here)
        mask: [N] bool, True where the residue is valid
        torsion_angles: [N, 6] sin/cos of phi, psi, omega
        chi_angles: [N, 4] sin/cos of chi1, chi2
        chain_id: 0 = receptor, 1 = binder
    Returns:
        node_feats: [N, NODE_DIM] = [one-hot AA | torsions | chi | chain flag]
    """
    n_res = len(residues)
    residue_types = get_aa_indices(residues)

    # Amino-acid one-hot; rows of masked-out residues are left all-zero.
    aa_onehot = np.zeros((n_res, NUM_AA), dtype=np.float32)
    valid_rows = [row for row in range(n_res) if mask[row]]
    for row in valid_rows:
        aa_onehot[row, residue_types[row]] = 1.0

    # Constant column that tags every residue with its chain.
    chain_feat = np.full((n_res, 1), chain_id, dtype=np.float32)

    blocks = (
        aa_onehot,       # [N, NUM_AA]
        torsion_angles,  # [N, 6]
        chi_angles,      # [N, 4]
        chain_feat,      # [N, 1]
    )
    return np.concatenate(blocks, axis=-1)
def extract_edge_features(coords_i, frames_i, coords_j, frames_j,
                          seq_idx_i, seq_idx_j, chain_i, chain_j, mask_i, mask_j):
    """
    Compute pairwise edge features between residue sets i and j.

    All terms are vectorized over every (i, j) pair. The features are intended
    to be SE(3)-invariant; see the NOTE(review) in the body about the rotation
    terms.

    Args:
        coords_i: [N_i, 4, 3] backbone coords of set i (only its length is used)
        frames_i: (origins_i [N_i, 3], rotations_i [N_i, 3, 3])
        coords_j: [N_j, 4, 3] (only its length is used)
        frames_j: (origins_j [N_j, 3], rotations_j [N_j, 3, 3])
        seq_idx_i: [N_i] integer sequence indices (for sequence separation)
        seq_idx_j: [N_j] integer sequence indices
        chain_i: int chain label of set i
        chain_j: int chain label of set j
        mask_i: [N_i] bool
        mask_j: [N_j] bool
    Returns:
        edge_feats: [N_i, N_j, EDGE_DIM] float32, laid out as
            [RBF distance (16) | direction (3) | relative rotation (9) |
             sequence separation (8) | same-chain flag (1)].
            Rows/columns belonging to masked residues are all zero.
    """
    n_sep_bins = 8  # must match the sequence-separation slice counted in EDGE_DIM

    N_i, N_j = len(coords_i), len(coords_j)
    origins_i, rotations_i = frames_i
    origins_j, rotations_j = frames_j
    seq_idx_i = np.asarray(seq_idx_i)
    seq_idx_j = np.asarray(seq_idx_j)
    mask_i = np.asarray(mask_i, dtype=bool)
    mask_j = np.asarray(mask_j, dtype=bool)

    # --- Distance features (frame origins are used as the residue positions) ---
    diff = origins_j[None, :, :] - origins_i[:, None, :]              # [N_i, N_j, 3]
    dist = np.sqrt((diff ** 2).sum(axis=-1))                          # [N_i, N_j]
    dist_rbf = rbf_encode(dist, d_min=0., d_max=20., n_bins=16)       # [N_i, N_j, 16]

    # --- Direction from i to j, expressed through the rotation of i ---
    # The epsilon keeps the i == j diagonal (dist == 0) finite.
    unit_diff = diff / (dist[..., None] + 1e-8)                       # [N_i, N_j, 3]
    # NOTE(review): this einsum contracts the SECOND index of rotations_i with
    # the vector, i.e. local_dir[i, j] = R_i @ u_ij, and the one below computes
    # rel_rot[i, j] = R_i @ R_j (no transpose in either). Whether R_i @ u_ij is
    # the frame-local direction depends on whether compute_backbone_frames
    # stores the local axes as rows or as columns, which is not visible from
    # this file -- TODO confirm. Assuming frames rotate together with the
    # coordinates, R_i @ R_j is not invariant to a global rotation in either
    # layout (the invariant forms are R_i^T R_j for column axes and R_i R_j^T
    # for row axes). Left unchanged here because altering it changes the
    # features consumed by any already-trained scorer.
    local_dir = np.einsum('ikl,ijl->ijk', rotations_i, unit_diff)     # [N_i, N_j, 3]

    # --- Relative rotation term, flattened to 9 values per pair ---
    rel_rot = np.einsum('ikl,jlm->ijkm', rotations_i, rotations_j)    # [N_i, N_j, 3, 3]
    rel_rot_flat = rel_rot.reshape(N_i, N_j, 9)                       # [N_i, N_j, 9]

    # --- Sequence separation ---
    if chain_i != chain_j or N_i * N_j == 0:
        # Cross-chain pairs carry no sequence-separation signal by convention.
        sep_enc = np.zeros((N_i, N_j, n_sep_bins), dtype=np.float32)
    else:
        sep = seq_idx_j[None, :] - seq_idx_i[:, None]                 # [N_i, N_j]
        # Encode each distinct separation once and gather, instead of calling
        # seq_sep_encode for every one of the N_i * N_j pairs.
        uniq_sep, inverse = np.unique(sep.reshape(-1), return_inverse=True)
        table = np.stack([seq_sep_encode(s, n_bins=n_sep_bins) for s in uniq_sep])
        sep_enc = table[inverse.reshape(-1)].reshape(N_i, N_j, n_sep_bins)

    # --- Same chain indicator ---
    same_chain = float(chain_i == chain_j)
    same_chain_feat = np.full((N_i, N_j, 1), same_chain, dtype=np.float32)

    # --- Concatenate ---
    edge_feats = np.concatenate([
        dist_rbf,         # [N_i, N_j, 16]
        local_dir,        # [N_i, N_j, 3]
        rel_rot_flat,     # [N_i, N_j, 9]
        sep_enc,          # [N_i, N_j, 8]
        same_chain_feat,  # [N_i, N_j, 1]
    ], axis=-1)           # [N_i, N_j, 37]

    # Zero out edges involving masked residues
    edge_feats[~mask_i, :, :] = 0.0
    edge_feats[:, ~mask_j, :] = 0.0
    return edge_feats.astype(np.float32)
def build_interface_graph(rec_residues, rec_coords, rec_mask,
                          binder_residues, binder_coords, binder_mask,
                          rec_interface_mask, binder_interface_mask,
                          max_nodes: int = 128):
    """
    Build a joint interface graph combining receptor and binder interface residues.

    Receptor nodes come first, binder nodes second. Each chain contributes at
    most ``max_nodes // 2`` residues; extra interface residues are dropped from
    the end of the index list.

    Args:
        rec_residues / binder_residues: per-chain residue lists
        rec_coords / binder_coords: [L, 4, 3] backbone coords per chain
        rec_mask / binder_mask: [L] bool validity mask per chain
        rec_interface_mask / binder_interface_mask: [L] bool, True for
            residues that belong to the interface
        max_nodes: upper bound on the total number of graph nodes
    Returns:
        None if neither chain has an interface residue, otherwise a dict with:
            node_feats: [N_total, NODE_DIM] float32
            edge_feats: [N_total, N_total, EDGE_DIM] float32
            node_mask: [N_total] bool
            n_rec: int (number of receptor interface nodes)
            n_binder: int (number of binder interface nodes)
            rec_iface_idx: [n_rec] original receptor residue indices
            binder_iface_idx: [n_binder] original binder residue indices
    """
    # Select interface residues, capping each chain at half the node budget.
    per_chain_cap = max_nodes // 2
    rec_iface_idx = np.where(rec_interface_mask)[0]
    binder_iface_idx = np.where(binder_interface_mask)[0]
    if len(rec_iface_idx) > per_chain_cap:
        rec_iface_idx = rec_iface_idx[:per_chain_cap]
    if len(binder_iface_idx) > per_chain_cap:
        binder_iface_idx = binder_iface_idx[:per_chain_cap]

    n_rec = len(rec_iface_idx)
    n_binder = len(binder_iface_idx)
    if n_rec + n_binder == 0:
        return None

    # Extract coords and validity masks for interface residues
    rec_iface_coords = rec_coords[rec_iface_idx]            # [n_rec, 4, 3]
    binder_iface_coords = binder_coords[binder_iface_idx]   # [n_binder, 4, 3]
    rec_iface_mask = rec_mask[rec_iface_idx]
    binder_iface_mask = binder_mask[binder_iface_idx]

    # Compute backbone frames
    rec_origins, rec_rotations = compute_backbone_frames(rec_iface_coords, rec_iface_mask)
    binder_origins, binder_rotations = compute_backbone_frames(binder_iface_coords, binder_iface_mask)

    # Torsion angles are computed on the interface subset only, so neighbouring
    # rows are not necessarily sequence-adjacent; proper phi/psi would need the
    # full-chain coords. This is a local approximation.
    rec_torsion = compute_torsion_angles(rec_iface_coords, rec_iface_mask)
    binder_torsion = compute_torsion_angles(binder_iface_coords, binder_iface_mask)

    # Extract residues
    rec_iface_residues = [rec_residues[i] for i in rec_iface_idx]
    binder_iface_residues = [binder_residues[i] for i in binder_iface_idx]

    # Compute sidechain chi1/chi2 angles
    rec_chi = compute_chi_angles(rec_iface_residues, rec_iface_mask)
    binder_chi = compute_chi_angles(binder_iface_residues, binder_iface_mask)

    # Node features
    rec_node_feats = extract_node_features(
        rec_iface_residues, rec_iface_coords, rec_iface_mask, rec_torsion, rec_chi, chain_id=0
    )  # [n_rec, NODE_DIM]
    binder_node_feats = extract_node_features(
        binder_iface_residues, binder_iface_coords, binder_iface_mask, binder_torsion, binder_chi, chain_id=1
    )  # [n_binder, NODE_DIM]
    node_feats = np.concatenate([rec_node_feats, binder_node_feats], axis=0)  # [N, NODE_DIM]
    node_mask = np.concatenate([rec_iface_mask, binder_iface_mask], axis=0)

    # Edge features are computed in one pass over the joint node set
    # (covers the RR, RB, BR and BB blocks).
    all_coords = np.concatenate([rec_iface_coords, binder_iface_coords], axis=0)
    all_origins = np.concatenate([rec_origins, binder_origins], axis=0)
    all_rotations = np.concatenate([rec_rotations, binder_rotations], axis=0)
    # Binder indices are shifted past the receptor length so the two chains
    # never share a sequence index. Because the call below passes equal chain
    # labels, extract_edge_features does not zero cross-chain separations;
    # they are encoded from these shifted indices instead.
    all_seq_idx = np.concatenate([rec_iface_idx, binder_iface_idx + len(rec_residues)], axis=0)
    all_chain = np.array([0] * n_rec + [1] * n_binder, dtype=np.int32)

    frames_all = (all_origins, all_rotations)
    edge_feats = extract_edge_features(
        all_coords, frames_all,
        all_coords, frames_all,
        all_seq_idx, all_seq_idx,
        -1, -1,  # placeholder labels; per-pair chain identity is patched below
        node_mask, node_mask
    )  # [N, N, EDGE_DIM]

    # Patch same_chain feature (last dim) using actual chain IDs. The patch is
    # restricted to valid pairs so that edges touching a masked residue stay
    # all-zero, as extract_edge_features left them.
    same_chain_feat = (all_chain[:, None] == all_chain[None, :]).astype(np.float32)
    valid = np.asarray(node_mask, dtype=bool)
    pair_valid = (valid[:, None] & valid[None, :]).astype(np.float32)
    edge_feats[:, :, -1] = same_chain_feat * pair_valid

    return {
        'node_feats': node_feats.astype(np.float32),   # [N, NODE_DIM]
        'edge_feats': edge_feats.astype(np.float32),   # [N, N, EDGE_DIM]
        'node_mask': node_mask,                        # [N]
        'n_rec': n_rec,
        'n_binder': n_binder,
        'rec_iface_idx': rec_iface_idx,                # [n_rec] original residue indices
        'binder_iface_idx': binder_iface_idx,          # [n_binder] original residue indices
    }