| """ |
| SE(3)-invariant feature extraction for interface graphs. |
| Node and edge features used by the Q_theta scorer. |
| """ |
|
|
| import os |
| import sys |
| import numpy as np |
|
|
| |
# Directory two levels above this file (the parent of this file's own directory).
# It is prepended to sys.path before the `utils.pdb_utils` import below --
# presumably so the `utils` package resolves when the project is not installed;
# confirm against the repository layout.
_CODE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _CODE_DIR not in sys.path:
    # Insert at position 0 so this directory takes precedence over later entries.
    sys.path.insert(0, _CODE_DIR)
|
|
| from utils.pdb_utils import ( |
| rbf_encode, compute_backbone_frames, compute_torsion_angles, |
| get_aa_indices, compute_chi_angles, get_cb_positions, NUM_AA |
| ) |
|
|
| |
| |
# Node feature width, in the column order concatenated by extract_node_features:
# amino-acid one-hot (NUM_AA) + 6 torsion + 4 chi + 1 chain id.
NODE_DIM = NUM_AA + 6 + 4 + 1
# Edge feature width, in the channel order concatenated by extract_edge_features:
# 16 distance RBF + 3 local direction + 9 flattened relative rotation
# + 8 sequence-separation bins + 1 same-chain flag.
EDGE_DIM = 16 + 3 + 9 + 8 + 1
|
|
# Sequence separations are clipped to [-MAX_SEQ_SEP, MAX_SEQ_SEP] before binning.
MAX_SEQ_SEP = 32


def seq_sep_encode(sep, n_bins=8, max_sep=MAX_SEQ_SEP):
    """
    One-hot bin-encode signed sequence separation.

    The range [-max_sep, max_sep] is split into n_bins equal-width bins.
    Values outside the range are clipped to the nearest boundary, so they
    land in the first or last bin.

    Args:
        sep: scalar or array-like of sequence separations (any shape)
        n_bins: number of bins / length of the one-hot axis
        max_sep: clipping bound for |sep|

    Returns:
        float32 array of shape sep.shape + (n_bins,); a scalar sep yields
        shape (n_bins,). Exactly one entry along the last axis is 1.0.
    """
    edges = np.linspace(-max_sep, max_sep, n_bins + 1)
    clipped = np.clip(sep, -max_sep, max_sep)
    # np.digitize is 1-based relative to `edges`; a value equal to the upper
    # edge returns n_bins + 1, so clip it back into the last bin.
    bin_idx = np.clip(np.digitize(clipped, edges) - 1, 0, n_bins - 1)
    # Row lookup in an identity matrix yields the one-hot for every element
    # at once, which keeps array inputs from collapsing into one multi-hot vector.
    return np.eye(n_bins, dtype=np.float32)[bin_idx]
|
|
|
|
def extract_node_features(residues, coords, mask, torsion_angles, chi_angles, chain_id):
    """
    Assemble the per-residue node feature matrix for one chain.

    Args:
        residues: list of Bio.PDB residues; passed to get_aa_indices
        coords: [N, 4, 3] backbone coords (accepted for interface
            compatibility; not read by this function)
        mask: [N] per-residue validity flags; a residue with a falsy flag
            keeps an all-zero amino-acid one-hot
        torsion_angles: [N, 6] sin/cos of phi, psi, omega (copied through)
        chi_angles: [N, 4] sin/cos of chi1, chi2 (copied through)
        chain_id: value written into the last column of every row
            (0 = receptor, 1 = binder)

    Returns:
        node_feats: [N, NUM_AA + 6 + 4 + 1], columns ordered as
            amino-acid one-hot | torsion | chi | chain id
    """
    n_res = len(residues)
    residue_types = get_aa_indices(residues)

    # One-hot amino-acid identity; only rows whose mask flag is truthy are set.
    one_hot = np.zeros((n_res, NUM_AA), dtype=np.float32)
    for row in np.flatnonzero(np.asarray(mask)[:n_res]):
        one_hot[row, residue_types[row]] = 1.0

    # Constant chain indicator column, broadcast to every residue.
    chain_column = np.full((n_res, 1), chain_id, dtype=np.float32)

    feature_parts = [one_hot, torsion_angles, chi_angles, chain_column]
    return np.concatenate(feature_parts, axis=-1)
|
|
|
|
def extract_edge_features(coords_i, frames_i, coords_j, frames_j,
                          seq_idx_i, seq_idx_j, chain_i, chain_j, mask_i, mask_j):
    """
    Compute pairwise edge features between residue sets i and j.
    Vectorized over all pairs (no per-pair Python loop).

    Args:
        coords_i: [N_i, 4, 3] backbone coords of set i (only its length is used)
        frames_i: (origins_i [N_i, 3], rotations_i [N_i, 3, 3])
        coords_j: [N_j, 4, 3]
        frames_j: (origins_j [N_j, 3], rotations_j [N_j, 3, 3])
        seq_idx_i: [N_i] integer sequence indices (for sequence separation)
        seq_idx_j: [N_j] integer sequence indices
        chain_i: int chain label of set i
        chain_j: int chain label of set j
        mask_i: [N_i] bool (or 0/1) validity mask
        mask_j: [N_j] bool (or 0/1) validity mask

    Returns:
        edge_feats: [N_i, N_j, EDGE_DIM] float32, channels ordered as
            16 distance RBF | 3 local direction | 9 relative rotation |
            8 sequence-separation one-hot | 1 same-chain flag.
            The sequence-separation channels are all zero when
            chain_i != chain_j. Every edge touching a masked residue
            is all zero.
    """
    n_sep_bins = 8
    N_i, N_j = len(coords_i), len(coords_j)
    origins_i, rotations_i = frames_i
    origins_j, rotations_j = frames_j

    # Pairwise displacement between frame origins, and its norm.
    diff = origins_j[None, :, :] - origins_i[:, None, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    dist_rbf = rbf_encode(dist, d_min=0., d_max=20., n_bins=16)

    # Unit direction i -> j; the epsilon keeps the i == j diagonal finite.
    unit_diff = diff / (dist[..., None] + 1e-8)

    # local_dir[i, j] = rotations_i[i] @ unit_diff[i, j]
    local_dir = np.einsum('ikl,ijl->ijk', rotations_i, unit_diff)

    # rel_rot[i, j] = rotations_i[i] @ rotations_j[j]
    # NOTE(review): this is a plain matrix product with no transpose. It is
    # not unchanged by a global rotation of the complex under either frame
    # convention; a frame-relative rotation would be R_i @ R_j^T (axes stored
    # as rows, consistent with local_dir above) or R_i^T @ R_j (axes stored as
    # columns). Confirm against compute_backbone_frames before changing it,
    # since any fix alters feature values seen by trained models.
    rel_rot = np.einsum('ikl,jlm->ijkm', rotations_i, rotations_j)
    rel_rot_flat = rel_rot.reshape(N_i, N_j, 9)

    if chain_i == chain_j:
        # Same binning as seq_sep_encode's defaults (8 bins over
        # [-MAX_SEQ_SEP, MAX_SEQ_SEP]), applied to the whole pair grid at once
        # instead of one Python call per pair.
        sep = np.asarray(seq_idx_j)[None, :] - np.asarray(seq_idx_i)[:, None]
        bin_edges = np.linspace(-MAX_SEQ_SEP, MAX_SEQ_SEP, n_sep_bins + 1)
        bin_idx = np.digitize(np.clip(sep, -MAX_SEQ_SEP, MAX_SEQ_SEP), bin_edges) - 1
        bin_idx = np.clip(bin_idx, 0, n_sep_bins - 1)
        sep_enc = np.eye(n_sep_bins, dtype=np.float32)[bin_idx]
    else:
        # Sequence separation is not defined across chains.
        sep_enc = np.zeros((N_i, N_j, n_sep_bins), dtype=np.float32)

    same_chain = float(chain_i == chain_j)
    same_chain_feat = np.full((N_i, N_j, 1), same_chain, dtype=np.float32)

    edge_feats = np.concatenate([
        dist_rbf,
        local_dir,
        rel_rot_flat,
        sep_enc,
        same_chain_feat,
    ], axis=-1)

    # Coerce to bool first: `~` on an integer mask is a bitwise NOT and would
    # produce index values rather than a boolean selector.
    invalid_i = ~np.asarray(mask_i, dtype=bool)
    invalid_j = ~np.asarray(mask_j, dtype=bool)
    edge_feats[invalid_i, :, :] = 0.0
    edge_feats[:, invalid_j, :] = 0.0

    return edge_feats.astype(np.float32)
|
|
|
|
def _interface_chain_features(residues, coords, mask, iface_idx, chain_id):
    """Slice one chain down to its interface residues and compute its frames and node features."""
    iface_coords = coords[iface_idx]
    iface_mask = mask[iface_idx]
    iface_residues = [residues[i] for i in iface_idx]

    origins, rotations = compute_backbone_frames(iface_coords, iface_mask)
    torsion = compute_torsion_angles(iface_coords, iface_mask)
    chi = compute_chi_angles(iface_residues, iface_mask)
    node_feats = extract_node_features(
        iface_residues, iface_coords, iface_mask, torsion, chi, chain_id=chain_id
    )
    return iface_coords, iface_mask, origins, rotations, node_feats


def build_interface_graph(rec_residues, rec_coords, rec_mask,
                          binder_residues, binder_coords, binder_mask,
                          rec_interface_mask, binder_interface_mask,
                          max_nodes: int = 128):
    """
    Build a joint interface graph combining receptor and binder interface residues.

    Receptor nodes come first, followed by binder nodes. Each chain contributes
    at most max_nodes // 2 interface residues; when a chain has more, the ones
    with the lowest residue indices are kept.

    Args:
        rec_residues, rec_coords, rec_mask: receptor residues, backbone
            coords and per-residue validity mask (indexable by residue index)
        binder_residues, binder_coords, binder_mask: same for the binder
        rec_interface_mask: per-residue flags selecting receptor interface residues
        binder_interface_mask: per-residue flags selecting binder interface residues
        max_nodes: upper bound on the total number of graph nodes

    Returns:
        None when no interface residue is selected on either chain, otherwise
        a dict with:
            node_feats: [N_total, NODE_DIM] float32
            edge_feats: [N_total, N_total, EDGE_DIM] float32
            node_mask: [N_total] bool
            n_rec: int (number of receptor interface nodes)
            n_binder: int (number of binder interface nodes)
            rec_iface_idx: indices into the receptor of the selected residues
            binder_iface_idx: indices into the binder of the selected residues
    """
    per_chain_cap = max_nodes // 2
    rec_iface_idx = np.where(rec_interface_mask)[0][:per_chain_cap]
    binder_iface_idx = np.where(binder_interface_mask)[0][:per_chain_cap]

    n_rec = len(rec_iface_idx)
    n_binder = len(binder_iface_idx)
    if n_rec + n_binder == 0:
        return None

    (rec_iface_coords, rec_iface_mask,
     rec_origins, rec_rotations, rec_node_feats) = _interface_chain_features(
        rec_residues, rec_coords, rec_mask, rec_iface_idx, chain_id=0
    )
    (binder_iface_coords, binder_iface_mask,
     binder_origins, binder_rotations, binder_node_feats) = _interface_chain_features(
        binder_residues, binder_coords, binder_mask, binder_iface_idx, chain_id=1
    )

    node_feats = np.concatenate([rec_node_feats, binder_node_feats], axis=0)
    node_mask = np.concatenate([rec_iface_mask, binder_iface_mask], axis=0)

    all_coords = np.concatenate([rec_iface_coords, binder_iface_coords], axis=0)
    all_origins = np.concatenate([rec_origins, binder_origins], axis=0)
    all_rotations = np.concatenate([rec_rotations, binder_rotations], axis=0)
    # Binder indices are shifted past the receptor so the two chains never
    # share a sequence index.
    all_seq_idx = np.concatenate([rec_iface_idx, binder_iface_idx + len(rec_residues)], axis=0)
    all_chain = np.array([0] * n_rec + [1] * n_binder, dtype=np.int32)

    # The joint call uses a dummy chain label (-1 for both sides) because the
    # edge extractor takes one scalar chain per side; the per-pair same-chain
    # channel is rewritten below.
    # NOTE(review): with equal dummy labels the sequence-separation channels
    # are also filled for receptor-binder pairs (from the shifted indices),
    # whereas extract_edge_features zeroes them when chains differ. Confirm
    # whether that is intended.
    frames_all = (all_origins, all_rotations)
    edge_feats = extract_edge_features(
        all_coords, frames_all,
        all_coords, frames_all,
        all_seq_idx, all_seq_idx,
        -1, -1,
        node_mask, node_mask
    )

    # Overwrite the last channel with the true per-pair same-chain flag, but
    # only for pairs where both residues are valid. extract_edge_features
    # zeroes every edge touching a masked residue, and an unconditional
    # overwrite would put 1.0 back into those edges.
    same_chain = all_chain[:, None] == all_chain[None, :]
    valid = np.asarray(node_mask, dtype=bool)
    pair_valid = valid[:, None] & valid[None, :]
    edge_feats[:, :, -1] = (same_chain & pair_valid).astype(np.float32)

    return {
        'node_feats': node_feats.astype(np.float32),
        'edge_feats': edge_feats.astype(np.float32),
        'node_mask': node_mask,
        'n_rec': n_rec,
        'n_binder': n_binder,
        'rec_iface_idx': rec_iface_idx,
        'binder_iface_idx': binder_iface_idx,
    }
|
|