"""
PDB parsing utilities for Allo-Designer.

Extracts backbone geometry, computes local frames, and identifies interface
residues.
"""
import warnings

import numpy as np
from Bio import PDB
from Bio.PDB import PDBParser, MMCIFParser, PDBIO
from Bio.PDB.Polypeptide import is_aa

warnings.filterwarnings("ignore", category=PDB.PDBExceptions.PDBConstructionWarning)

# Three-letter residue name -> integer class index. Anything not listed here
# is mapped to 'UNK' by get_aa_indices().
AA3_TO_IDX = {
    'ALA': 0, 'ARG': 1, 'ASN': 2, 'ASP': 3, 'CYS': 4,
    'GLN': 5, 'GLU': 6, 'GLY': 7, 'HIS': 8, 'ILE': 9,
    'LEU': 10, 'LYS': 11, 'MET': 12, 'PHE': 13, 'PRO': 14,
    'SER': 15, 'THR': 16, 'TRP': 17, 'TYR': 18, 'VAL': 19,
    'UNK': 20,
}
NUM_AA = 21  # 20 standard + UNK


def load_structure(pdb_path: str, model_id: int = 0):
    """
    Load a PDB/mmCIF file and return one of its models.

    Args:
        pdb_path: path to the structure file (str or path-like). Names ending
            in ``.cif`` / ``.mmcif`` (case-insensitive) are parsed as mmCIF,
            everything else as PDB.
        model_id: position of the wanted model in the parsed structure's
            model list (default 0 = first model).

    Returns:
        The selected model object.

    Raises:
        IndexError: if ``model_id`` is outside the list of parsed models.
    """
    path = str(pdb_path)
    # Compare case-insensitively so e.g. "1ABC.CIF" is not fed to the PDB parser.
    if path.lower().endswith(('.cif', '.mmcif')):
        parser = MMCIFParser(QUIET=True)
    else:
        parser = PDBParser(QUIET=True)
    struct = parser.get_structure("protein", path)
    return list(struct.get_models())[model_id]


def get_residues(chain, only_standard: bool = True):
    """
    Return the non-hetero residues of a chain as a list.

    Residues whose id hetero-flag is not ``' '`` (HETATM records, waters) are
    always dropped. With ``only_standard=True`` residues that fail
    ``is_aa(res, standard=True)`` are dropped as well.
    """
    return [
        res for res in chain.get_residues()
        if res.get_id()[0] == ' '
        and (not only_standard or is_aa(res, standard=True))
    ]


def get_backbone_coords(residues):
    """
    Extract backbone atom coordinates (N, CA, C, O) for each residue.

    Returns:
        coords: [N_res, 4, 3] float32, atom order N, CA, C, O.
        mask:   [N_res] bool, True when N, CA and C were all found.

    Notes:
        * A missing O does not invalidate the residue; its slot is filled
          with the C position as a placeholder.
        * When N, CA or C is missing the row may be partially filled (atoms
          read before the missing one are kept), so always consult ``mask``.
    """
    n_res = len(residues)
    coords = np.zeros((n_res, 4, 3), dtype=np.float32)
    mask = np.zeros(n_res, dtype=bool)

    for i, res in enumerate(residues):
        try:
            for slot, atom_name in enumerate(('N', 'CA', 'C')):
                coords[i, slot] = res[atom_name].get_vector().get_array()
            if 'O' in res:
                coords[i, 3] = res['O'].get_vector().get_array()
            else:
                # No carbonyl oxygen recorded: reuse the C position.
                coords[i, 3] = coords[i, 2]
        except KeyError:
            # A required backbone atom is absent; leave mask[i] False.
            continue
        mask[i] = True

    return coords, mask


def get_aa_indices(residues):
    """Return an int64 array with the AA3_TO_IDX index of each residue (UNK if unlisted)."""
    unknown = AA3_TO_IDX['UNK']
    indices = [AA3_TO_IDX.get(res.get_resname(), unknown) for res in residues]
    return np.array(indices, dtype=np.int64)


def compute_backbone_frames(coords, mask):
    """
    Compute local backbone frames from N, CA, C atoms.

    Frame: z-axis = CA->C,
           y-axis = component of CA->N perpendicular to z,
           x-axis = y x z.

    Residues that are masked out, or whose CA->C / orthogonalised CA->N
    vector is shorter than 1e-6, get the identity rotation.

    Returns:
        origins:   [N, 3] CA positions (a view into ``coords``, not a copy)
        rotations: [N, 3, 3] float32 rotation matrices (columns are x, y, z)
    """
    n_res = coords.shape[0]
    origins = coords[:, 1, :]

    # Start from identity everywhere; only well-defined frames are overwritten.
    rotations = np.zeros((n_res, 3, 3), dtype=np.float32)
    rotations[:] = np.eye(3)

    for i in range(n_res):
        if not mask[i]:
            continue
        ca = coords[i, 1]

        # z-axis: CA -> C
        z = coords[i, 2] - ca
        z_norm = np.linalg.norm(z)
        if z_norm < 1e-6:
            continue
        z = z / z_norm

        # y-axis: CA -> N with its z component removed (Gram-Schmidt)
        y = coords[i, 0] - ca
        y = y - np.dot(y, z) * z
        y_norm = np.linalg.norm(y)
        if y_norm < 1e-6:
            continue
        y = y / y_norm

        # x-axis completes the right-handed triad
        x = np.cross(y, z)
        rotations[i] = np.stack([x, y, z], axis=-1)  # columns are axes

    return origins, rotations


def _pairwise_distances(a, b):
    """Euclidean distance matrix [len(a), len(b)] between two [*, 3] point sets."""
    diff = a[:, None, :] - b[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def _dihedral_4pts(p0, p1, p2, p3):
    """
    Compute the dihedral angle defined by four 3D points (radians).

    Returns 0.0 when either plane normal is shorter than 1e-6 (three
    consecutive points are (nearly) collinear or coincident).

    NOTE(review): with m1 = n1 x b2_hat the sine term equals
    -b1.(b2 x b3) / (|n1| |n2|), i.e. the result has the opposite sign to the
    IUPAC convention atan2(|b2| b1.(b2 x b3), n1.n2). For example the points
    (1,0,0), (0,0,0), (0,0,1), (0,1,1) give -pi/2 here. Left unchanged because
    downstream features may have been built on this sign; confirm before
    flipping.
    """
    b1 = p1 - p0
    b2 = p2 - p1
    b3 = p3 - p2
    n1 = np.cross(b1, b2)
    n2 = np.cross(b2, b3)
    n1_norm = np.linalg.norm(n1)
    n2_norm = np.linalg.norm(n2)
    if n1_norm < 1e-6 or n2_norm < 1e-6:
        return 0.0
    n1 = n1 / n1_norm
    n2 = n2 / n2_norm
    m1 = np.cross(n1, b2 / (np.linalg.norm(b2) + 1e-8))
    return np.arctan2(np.dot(m1, n2), np.dot(n1, n2))


def compute_torsion_angles(coords, mask):
    """
    Compute backbone torsion angles (phi, psi, omega) for each residue.

    Angles are defined on consecutive entries of ``coords``:
        phi_i:   C_{i-1} - N_i  - CA_i - C_i
        psi_i:   N_i     - CA_i - C_i  - N_{i+1}
        omega_i: CA_{i-1} - C_{i-1} - N_i - CA_i  (peptide bond preceding i)

    An angle is only filled in when residue i and the neighbour it needs are
    both valid in ``mask``; otherwise its sin and cos stay 0. No chain-break
    check is made: array neighbours are treated as sequence neighbours.
    The sign convention is that of ``_dihedral_4pts``.

    Returns:
        [N, 6] float32: (sin phi, cos phi, sin psi, cos psi, sin omega, cos omega)
    """
    n_res = len(coords)
    angles = np.zeros((n_res, 6), dtype=np.float32)

    for i in range(n_res):
        if not mask[i]:
            continue
        n_i, ca_i, c_i = coords[i, 0], coords[i, 1], coords[i, 2]

        # phi and omega both need the previous residue.
        if i > 0 and mask[i - 1]:
            ca_prev, c_prev = coords[i - 1, 1], coords[i - 1, 2]
            phi = _dihedral_4pts(c_prev, n_i, ca_i, c_i)
            angles[i, 0] = np.sin(phi)
            angles[i, 1] = np.cos(phi)
            omega = _dihedral_4pts(ca_prev, c_prev, n_i, ca_i)
            angles[i, 4] = np.sin(omega)
            angles[i, 5] = np.cos(omega)

        # psi needs the next residue.
        if i < n_res - 1 and mask[i + 1]:
            psi = _dihedral_4pts(n_i, ca_i, c_i, coords[i + 1, 0])
            angles[i, 2] = np.sin(psi)
            angles[i, 3] = np.cos(psi)

    return angles


def get_interface_residues(rec_coords, binder_coords, rec_mask, binder_mask,
                           cutoff: float = 8.0):
    """
    Find interface residues using CA-CA distances.

    A receptor residue is at the interface when it lies within ``cutoff`` of
    any valid binder CA, and vice versa. Residues that are False in their
    mask never count as, or contribute to, a contact.

    Args:
        rec_coords:    [N_rec, 4, 3] backbone coordinates (index 1 = CA)
        binder_coords: [N_binder, 4, 3] backbone coordinates (index 1 = CA)
        rec_mask:      [N_rec] validity mask
        binder_mask:   [N_binder] validity mask
        cutoff:        contact distance threshold (strict ``<``)

    Returns:
        rec_interface:    bool array [N_rec]
        binder_interface: bool array [N_binder]
    """
    dist = _pairwise_distances(rec_coords[:, 1, :], binder_coords[:, 1, :])

    # Push invalid residues infinitely far away so they can never be in contact.
    # Coerce to bool arrays so lists / int masks behave like boolean masks.
    dist[~np.asarray(rec_mask, dtype=bool), :] = np.inf
    dist[:, ~np.asarray(binder_mask, dtype=bool)] = np.inf

    in_contact = dist < cutoff
    return in_contact.any(axis=1), in_contact.any(axis=0)


def align_structures(mobile_ca, ref_ca, mobile_coords=None):
    """
    Kabsch alignment: superimpose mobile onto ref using CA positions.

    Args:
        mobile_ca:     [N, 3] CA coordinates to move
        ref_ca:        [N, 3] reference CA coordinates (same shape)
        mobile_coords: optional [N_res, N_atoms, 3] coordinates that receive
            the same rigid transform

    Returns:
        ``(mobile_ca_aligned, R)`` when ``mobile_coords`` is None, otherwise
        ``(mobile_ca_aligned, R, mobile_coords_aligned)``. ``R`` is the 3x3
        rotation applied to the centred mobile coordinates.
    """
    assert mobile_ca.shape == ref_ca.shape, "Must have same number of residues"

    # Centre both point sets on their centroids.
    mobile_center = mobile_ca.mean(axis=0)
    ref_center = ref_ca.mean(axis=0)
    m = mobile_ca - mobile_center
    r = ref_ca - ref_center

    # Optimal rotation from the SVD of the covariance matrix.
    H = m.T @ r
    U, S, Vt = np.linalg.svd(H)
    # d = -1 flips the last axis so that R is a proper rotation, not a reflection.
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    D = np.diag([1, 1, d])
    R = Vt.T @ D @ U.T

    mobile_ca_aligned = (m @ R.T) + ref_center

    if mobile_coords is None:
        return mobile_ca_aligned, R

    # Apply the same centring, rotation and translation to every atom.
    n_res, n_atoms, _ = mobile_coords.shape
    flat = mobile_coords.reshape(-1, 3) - mobile_center
    aligned_flat = (flat @ R.T) + ref_center
    mobile_coords_aligned = aligned_flat.reshape(n_res, n_atoms, 3)
    return mobile_ca_aligned, R, mobile_coords_aligned


def compute_ca_rmsd(coords1, coords2, mask=None):
    """
    Compute CA-RMSD between two sets of backbone coordinates.

    No superposition is performed; the inputs are compared as given.

    Args:
        coords1, coords2: [N, 4, 3] backbone coordinates (index 1 = CA)
        mask: optional index/boolean selector applied to both CA arrays
    """
    ca1 = coords1[:, 1, :]
    ca2 = coords2[:, 1, :]
    if mask is not None:
        ca1, ca2 = ca1[mask], ca2[mask]
    sq_dev = ((ca1 - ca2) ** 2).sum(axis=-1)
    return np.sqrt(sq_dev.mean())


def compute_fraction_native_contacts(
    native_rec_ca,
    native_binder_ca,
    model_rec_ca=None,
    model_binder_ca=None,
    cutoff=8.0,
    # Legacy 2-arg signature support
    mask=None,
    delta=1.0,
):
    """
    Compute fraction of native inter-chain contacts (fNAT).

    fNAT = |recovered inter-chain contacts| / |native inter-chain contacts|

    A native contact is a (receptor_i, binder_j) pair with CA-CA distance
    < cutoff in the native complex. A contact is "recovered" if the same
    pair is < cutoff in the model complex.

    Args:
        native_rec_ca:    [N_rec, 3] receptor CA coords in native complex
        native_binder_ca: [N_bind, 3] binder CA coords in native complex
        model_rec_ca:     [N_rec, 3] receptor CA in model (default: same as native)
        model_binder_ca:  [N_bind, 3] binder CA in model (default: same as native)
        cutoff: contact distance threshold in Angstroms (default 8.0 for CA-CA)
        mask:   accepted for backward compatibility; not used by this function
        delta:  accepted for backward compatibility; not used by this function

    Returns:
        fNAT in [0, 1]. Returns 0.0 if no native contacts exist.
    """
    if model_rec_ca is None:
        model_rec_ca = native_rec_ca
    if model_binder_ca is None:
        model_binder_ca = native_binder_ca

    # Inter-chain distance matrices [N_rec, N_bind]
    native_contacts = _pairwise_distances(native_rec_ca, native_binder_ca) < cutoff
    model_contacts = _pairwise_distances(model_rec_ca, model_binder_ca) < cutoff

    n_native = native_contacts.sum()
    if n_native == 0:
        return 0.0
    recovered = native_contacts & model_contacts
    return float(recovered.sum()) / float(n_native)


def rbf_encode(distances, d_min=0.0, d_max=20.0, n_bins=16):
    """
    RBF encoding of distances using Gaussian basis functions.

    Centers are ``n_bins`` evenly spaced values in [d_min, d_max]; the width
    (sigma) equals the spacing between adjacent centers. With a single bin
    there is no spacing, so the full range ``d_max - d_min`` is used instead.

    Args:
        distances: array-like of distances (any shape, scalars allowed)

    Returns:
        float32 array of shape [*distances.shape, n_bins]
    """
    distances = np.asarray(distances)
    centers = np.linspace(d_min, d_max, n_bins)
    span = d_max - d_min
    # n_bins == 1 would divide by zero in the spacing formula.
    sigma = span / (n_bins - 1) if n_bins > 1 else span
    encoded = np.exp(-((distances[..., None] - centers) ** 2) / (2 * sigma ** 2))
    return encoded.astype(np.float32)


# Candidate sidechain atoms for chi1 (first atom after CB)
_CHI1_ATOMS = ['CG', 'CG1', 'OG', 'OG1', 'SG']
# Candidate sidechain atoms for chi2 (second dihedral: CA-CB-XG-XD)
# NOTE(review): 'CE', 'NE' and 'OE1' are epsilon-position names; they are only
# reached when none of the earlier candidates is present in the residue.
_CHI2_ATOMS = ['CD', 'CD1', 'SD', 'OD1', 'ND1', 'CE', 'NE', 'OE1']
# Every atom name compute_chi_angles() needs, built once instead of per atom.
_CHI_ATOM_NAMES = frozenset(('N', 'CA', 'CB', *_CHI1_ATOMS, *_CHI2_ATOMS))


def _first_present(atoms, candidates):
    """Return the position of the first candidate name found in ``atoms``, else None."""
    for name in candidates:
        if name in atoms:
            return atoms[name]
    return None


def compute_chi_angles(residues, mask):
    """
    Compute chi1 and chi2 sidechain torsion angles for each residue.

    Chi1: N  - CA - CB - XG   (first sidechain dihedral)
    Chi2: CA - CB - XG - XD   (second sidechain dihedral)

    XG / XD are the first names from ``_CHI1_ATOMS`` / ``_CHI2_ATOMS`` that
    the residue contains. Chi2 is only attempted when chi1 could be computed.
    Residues that are masked out or lack the required atoms (e.g. Gly, which
    has no CB) keep zeros. The sign convention is that of ``_dihedral_4pts``.

    Returns:
        chi_feats: [N, 4] float32 (sin_chi1, cos_chi1, sin_chi2, cos_chi2)
    """
    chi_feats = np.zeros((len(residues), 4), dtype=np.float32)

    for i, res in enumerate(residues):
        if not mask[i]:
            continue

        atoms = {
            atom.get_name(): atom.get_vector().get_array()
            for atom in res.get_atoms()
            if atom.get_name() in _CHI_ATOM_NAMES
        }
        n_pos = atoms.get('N')
        ca_pos = atoms.get('CA')
        cb_pos = atoms.get('CB')
        if n_pos is None or ca_pos is None or cb_pos is None:
            continue

        # Chi1: N - CA - CB - XG
        xg_pos = _first_present(atoms, _CHI1_ATOMS)
        if xg_pos is None:
            continue
        chi1 = _dihedral_4pts(n_pos, ca_pos, cb_pos, xg_pos)
        chi_feats[i, 0] = np.sin(chi1)
        chi_feats[i, 1] = np.cos(chi1)

        # Chi2: CA - CB - XG - XD
        xd_pos = _first_present(atoms, _CHI2_ATOMS)
        if xd_pos is None:
            continue
        chi2 = _dihedral_4pts(ca_pos, cb_pos, xg_pos, xd_pos)
        chi_feats[i, 2] = np.sin(chi2)
        chi_feats[i, 3] = np.cos(chi2)

    return chi_feats


def get_cb_positions(residues, coords, mask):
    """
    Return CB positions for each residue.

    The CA position from ``coords`` is used when the residue is masked out or
    has no CB atom (e.g. Gly).

    Returns:
        cb_pos: [N, 3] float32
    """
    cb_pos = coords[:, 1, :].copy()  # default to CA
    for i, res in enumerate(residues):
        if mask[i] and 'CB' in res:
            cb_pos[i] = res['CB'].get_vector().get_array()
    return cb_pos.astype(np.float32)


# Simplified hydrophobicity groups for contact energy
_HYDROPHOBIC = {'ALA', 'VAL', 'ILE', 'LEU', 'MET', 'PHE', 'TRP', 'PRO', 'TYR'}
_POS_CHARGED = {'ARG', 'LYS', 'HIS'}
_NEG_CHARGED = {'ASP', 'GLU'}


def _residue_group(resname):
    """Classify a residue name as 'H' (hydrophobic), '+', '-' or 'P' (everything else)."""
    if resname in _HYDROPHOBIC:
        return 'H'
    if resname in _POS_CHARGED:
        return '+'
    if resname in _NEG_CHARGED:
        return '-'
    return 'P'  # polar


def _pair_energy(group_a, group_b):
    """Contact energy of one residue pair given the two group labels."""
    charged = ('+', '-')
    if group_a == 'H' and group_b == 'H':
        return -1.0
    if group_a in charged and group_b in charged and group_a != group_b:
        return -0.5
    if (group_a == 'H' and group_b in charged) or (group_b == 'H' and group_a in charged):
        return 0.3
    return 0.0


def compute_contact_energy(rec_residues, binder_residues,
                           rec_cb, binder_cb,
                           rec_mask, binder_mask,
                           cutoff: float = 8.0):
    """
    Compute a simple CB-CB contact energy as a physics-based ddG proxy.

    Uses a 4-group hydrophobicity potential, summed over all receptor/binder
    pairs whose CB-CB distance is < cutoff:
        HH:   -1.0  (hydrophobic-hydrophobic, favorable)
        +-:   -0.5  (opposite charges, favorable)
        H+/-: +0.3  (hydrophobic-charged, unfavorable)
        else:  0.0

    Args:
        rec_residues, binder_residues: residue objects providing ``get_resname()``
        rec_cb, binder_cb: [n, 3] CB positions aligned with the residue lists
        rec_mask, binder_mask: validity masks; invalid residues make no contacts
        cutoff: contact distance threshold

    Returns:
        Python float ``sigmoid((energy - 5) / 5)``. It equals ~0.27 when there
        are no contacts and decreases as the summed energy becomes more
        negative (more favorable).
    """
    dist = _pairwise_distances(rec_cb, binder_cb)  # [n_rec, n_binder]

    # Invalid residues are pushed to infinity so they never form a contact.
    dist[~np.asarray(rec_mask, dtype=bool), :] = np.inf
    dist[:, ~np.asarray(binder_mask, dtype=bool)] = np.inf
    contact_mask = dist < cutoff

    # Classify each residue once rather than once per residue pair.
    rec_groups = [_residue_group(res.get_resname()) for res in rec_residues]
    binder_groups = [_residue_group(res.get_resname()) for res in binder_residues]

    # argwhere yields contact pairs in row-major (i, then j) order, so the
    # floating-point accumulation order matches a nested i/j loop.
    energy = 0.0
    for i, j in np.argwhere(contact_mask):
        energy += _pair_energy(rec_groups[i], binder_groups[j])

    # Logistic squashing centred at energy = 5 with scale 5.
    score = 1.0 / (1.0 + np.exp(-(energy - 5.0) / 5.0))
    return float(score)