lighteternal committed on
Commit 34195f1 · verified · 1 Parent(s): 6ee95bc

Add self-contained inference helper and clarify score semantics

Files changed (3)
  1. README.md +81 -17
  2. bioassayalign_compatibility.py +439 -0
  3. requirements.txt +6 -0
README.md CHANGED
@@ -57,6 +57,12 @@ This artifact uses:

  The final score comes from the learned compatibility head. It is not just a raw embedding dot product.

  ## Training Data

  Training uses a frozen public assay-compound corpus derived from:
@@ -137,39 +143,97 @@ Candidate list ranked by the model:

  The raw values above are model scores. In practice, read them as list-relative ranking values, not calibrated probabilities.

  ## How To Run It Locally

  ### Minimal local check from this repo

- This downloads only the model weights and metadata, not the raw assay dataset.

  ```bash
- MODEL_REPO_ID='lighteternal/BioAssayAlign-Qwen3-Embedding-0.6B-Compatibility' \
- LOCAL_MODEL_DIR='data/hf_compat_model_check' \
- bash scripts/score_compatibility_from_hf.sh \
-   --assay-title 'JAK2 inhibition assay' \
-   --description 'Cell-based luminescence assay measuring JAK2 inhibition in HEK293 cells.' \
-   --organism 'Homo sapiens' \
-   --readout 'luminescence' \
-   --assay-format 'cell-based' \
-   --assay-type 'inhibition' \
-   --target-uniprot O60674 \
-   --smiles 'CC(=O)Nc1ncc(C#N)c(Nc2ccc(F)c(Cl)c2)n1' \
-   --smiles 'c1ccccc1' \
-   --smiles 'CCO'
  ```

  ### Python usage

  ```python
- from bioassayalign.compat_inference import (
      AssayQuery,
-     load_compatibility_model,
      rank_compounds,
      serialize_assay_query,
  )

- model = load_compatibility_model("/path/to/model_dir")
  assay_text = serialize_assay_query(
      AssayQuery(
          title="JAK2 inhibition assay",
 
  The final score comes from the learned compatibility head. It is not just a raw embedding dot product.

+ This repo also includes a self-contained inference helper:
+ - `bioassayalign_compatibility.py`
+ - `requirements.txt`
+
+ You do not need the project GitHub repo to run the published model.
+
  ## Training Data

  Training uses a frozen public assay-compound corpus derived from:
 
  The raw values above are model scores. In practice, read them as list-relative ranking values, not calibrated probabilities.

+ ## What The Raw Score Means
+
+ The raw output is a learned ranking score from the compatibility head.
+
+ You can think of it as a **logit-like utility value**:
+ - higher is better
+ - differences inside the same submitted list matter
+ - absolute values across unrelated lists are not directly comparable
+
+ Example:
+ - a top candidate with score `-4.7`
+ - another candidate with score `-20.0`
+
+ does **not** mean the first compound has negative biological value. It only means the first item scored much better than the second one for that submitted assay-and-list context.
+
+ If you want a normalized shortlist view for one submitted list, you can convert the raw scores with:
+ - a min-max `0–100` relative ranking scale, or
+ - a softmax over the submitted list
+
+ Softmax example for one list:
+
+ ```python
+ from bioassayalign_compatibility import list_softmax_scores
+
+ scores = [-4.6947, -15.0503, -20.0474]
+ relative_probs = list_softmax_scores(scores)
+ print(relative_probs)
+ ```
+
+ Important:
+ - those softmax values are **not calibrated probabilities of success**
+ - they are only a normalized way to compare the candidates you submitted together
+
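Review note: the added README text mentions a min-max `0–100` relative ranking scale, while the helper file in this commit only defines `list_softmax_scores`. A minimal sketch of the min-max variant follows, assuming a hypothetical function name `list_minmax_scores` that is not part of `bioassayalign_compatibility.py`. Treating an all-equal list as `100.0` for every item is also an assumption, not something the commit specifies.

```python
from __future__ import annotations


def list_minmax_scores(scores: list[float]) -> list[float]:
    """Rescale raw scores from ONE submitted list onto a 0-100 range.

    Hypothetical helper for illustration; not defined in the published file.
    """
    if not scores:
        return []
    lo, hi = min(scores), max(scores)
    if hi == lo:
        # No spread to rescale: every candidate ties, so report all as top-ranked.
        return [100.0 for _ in scores]
    # Divide first so the best score maps to exactly 1.0 before scaling to 100.
    return [100.0 * ((s - lo) / (hi - lo)) for s in scores]


scores = [-4.6947, -15.0503, -20.0474]
relative = list_minmax_scores(scores)
print([round(value, 1) for value in relative])
```

Like the softmax view, these values only compare candidates submitted together; the best item always lands at 100 and the worst at 0, regardless of how good either is in absolute terms.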
  ## How To Run It Locally

  ### Minimal local check from this repo

+ This downloads only the published model files from this repo, not the raw assay dataset.

  ```bash
+ python -m pip install -r requirements.txt
+ python - <<'PY'
+ from bioassayalign_compatibility import (
+     AssayQuery,
+     load_compatibility_model_from_hub,
+     rank_compounds,
+     serialize_assay_query,
+ )
+
+ model = load_compatibility_model_from_hub(
+     "lighteternal/BioAssayAlign-Qwen3-Embedding-0.6B-Compatibility"
+ )
+ assay_text = serialize_assay_query(
+     AssayQuery(
+         title="JAK2 inhibition assay",
+         description="Cell-based luminescence assay measuring JAK2 inhibition in HEK293 cells.",
+         organism="Homo sapiens",
+         readout="luminescence",
+         assay_format="cell-based",
+         assay_type="inhibition",
+         target_uniprot=["O60674"],
+     )
+ )
+
+ results = rank_compounds(
+     model,
+     assay_text=assay_text,
+     smiles_list=[
+         "CC(=O)Nc1ncc(C#N)c(Nc2ccc(F)c(Cl)c2)n1",
+         "c1ccccc1",
+         "CCO",
+     ],
+ )
+ for row in results:
+     print(row)
+ PY
  ```

  ### Python usage

  ```python
+ from bioassayalign_compatibility import (
      AssayQuery,
+     load_compatibility_model_from_hub,
      rank_compounds,
      serialize_assay_query,
  )

+ model = load_compatibility_model_from_hub(
+     "lighteternal/BioAssayAlign-Qwen3-Embedding-0.6B-Compatibility"
+ )
  assay_text = serialize_assay_query(
      AssayQuery(
          title="JAK2 inhibition assay",
bioassayalign_compatibility.py ADDED
@@ -0,0 +1,439 @@
+ from __future__ import annotations
+
+ import hashlib
+ import json
+ import os
+ import re
+ from dataclasses import dataclass
+ from functools import lru_cache
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from huggingface_hub import snapshot_download
+ from huggingface_hub.utils import disable_progress_bars
+ from rdkit import Chem, DataStructs, RDLogger
+ from rdkit.Chem import AllChem, Crippen, Descriptors, Lipinski, MACCSkeys, rdMolDescriptors
+ from rdkit.Chem.MolStandardize import rdMolStandardize
+ from sentence_transformers import SentenceTransformer
+ from torch import nn
+ from transformers.utils import logging as transformers_logging
+
+ os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+ disable_progress_bars()
+ transformers_logging.set_verbosity_error()
+ RDLogger.DisableLog("rdApp.*")
+
+ DEFAULT_ASSAY_TASK = (
+     "Given a bioassay description and metadata, represent the assay for ranking compatible small molecules."
+ )
+ SECTION_ORDER = [
+     "ASSAY_TITLE",
+     "DESCRIPTION",
+     "ORGANISM",
+     "READOUT",
+     "ASSAY_FORMAT",
+     "ASSAY_TYPE",
+     "TARGET_UNIPROT",
+ ]
+ ASSAY_SECTION_RE = re.compile(r"\[(ASSAY_TITLE|DESCRIPTION|ORGANISM|READOUT|ASSAY_FORMAT|ASSAY_TYPE|TARGET_UNIPROT)\]\n")
+ ORGANISM_ALIASES = {
+     "9606": "homo_sapiens",
+     "10090": "mus_musculus",
+     "10116": "rattus_norvegicus",
+     "4932": "saccharomyces_cerevisiae",
+ }
+ DEFAULT_DESCRIPTOR_NAMES = (
+     "mol_wt",
+     "logp",
+     "tpsa",
+     "heavy_atoms",
+     "hbd",
+     "hba",
+     "rot_bonds",
+     "ring_count",
+     "aromatic_rings",
+     "aliphatic_rings",
+     "saturated_rings",
+     "fraction_csp3",
+     "heteroatoms",
+     "amide_bonds",
+     "fragments",
+     "formal_charge",
+     "max_atomic_num",
+     "metal_atom_count",
+     "halogen_count",
+     "nitrogen_count",
+     "oxygen_count",
+     "sulfur_count",
+     "phosphorus_count",
+     "fluorine_count",
+     "chlorine_count",
+     "bromine_count",
+     "iodine_count",
+     "aromatic_atom_count",
+     "spiro_atoms",
+     "bridgehead_atoms",
+ )
+ ORGANIC_LIKE_ATOMIC_NUMBERS = {1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53}
+
+
+ @dataclass
+ class AssayQuery:
+     title: str = ""
+     description: str = ""
+     organism: str = ""
+     readout: str = ""
+     assay_format: str = ""
+     assay_type: str = ""
+     target_uniprot: list[str] | None = None
+
+
+ def serialize_assay_query(query: AssayQuery) -> str:
+     targets = ", ".join(query.target_uniprot or [])
+     values = {
+         "ASSAY_TITLE": query.title.strip(),
+         "DESCRIPTION": query.description.strip(),
+         "ORGANISM": query.organism.strip(),
+         "READOUT": query.readout.strip(),
+         "ASSAY_FORMAT": query.assay_format.strip(),
+         "ASSAY_TYPE": query.assay_type.strip(),
+         "TARGET_UNIPROT": targets.strip(),
+     }
+     return "\n\n".join(f"[{key}]\n{values[key]}" for key in SECTION_ORDER)
+
+
+ def _parse_assay_sections(assay_text: str) -> dict[str, str]:
+     sections = {key: "" for key in SECTION_ORDER}
+     parts = ASSAY_SECTION_RE.split(assay_text)
+     for idx in range(1, len(parts), 2):
+         key = parts[idx]
+         value = parts[idx + 1] if idx + 1 < len(parts) else ""
+         if key in sections:
+             sections[key] = value.strip()
+     return sections
+
+
+ def _normalize_metadata_token(value: str) -> str:
+     return re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_")
+
+
+ def _normalize_organism_token(value: str) -> str:
+     raw = value.strip()
+     if not raw:
+         return ""
+     aliased = ORGANISM_ALIASES.get(raw, raw)
+     return _normalize_metadata_token(aliased)
+
+
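+ def _hash_bucket(value: str, dim: int) -> int:
+     # NOTE: built-in hash() on str is salted per interpreter process unless
+     # PYTHONHASHSEED is fixed, so bucket indices can differ between runs.
+     return abs(hash(value)) % max(dim, 1)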
+
+
+ def _assay_metadata_vector(assay_text: str, *, dim: int) -> np.ndarray:
+     if dim <= 0:
+         return np.zeros((0,), dtype=np.float32)
+     sections = _parse_assay_sections(assay_text)
+     tokens: list[str] = []
+     organism = _normalize_organism_token(sections.get("ORGANISM", ""))
+     if organism:
+         tokens.append(f"organism:{organism}")
+     for key in ("READOUT", "ASSAY_FORMAT", "ASSAY_TYPE"):
+         value = _normalize_metadata_token(sections.get(key, ""))
+         if value:
+             tokens.append(f"{key.lower()}:{value}")
+     for target in sections.get("TARGET_UNIPROT", "").split(","):
+         token = target.strip().upper()
+         if token:
+             tokens.append(f"target:{token}")
+     vec = np.zeros((dim,), dtype=np.float32)
+     for token in tokens:
+         vec[_hash_bucket(token, dim)] += 1.0
+     norm = float(np.linalg.norm(vec))
+     if norm > 0:
+         vec /= norm
+     return vec
+
+
+ @lru_cache(maxsize=1_000_000)
+ def _standardize_smiles_v2_cached(smiles: str) -> str | None:
+     mol = Chem.MolFromSmiles(smiles)
+     if mol is None:
+         return None
+     try:
+         mol = rdMolStandardize.Cleanup(mol)
+         mol = rdMolStandardize.FragmentParent(mol)
+         mol = rdMolStandardize.Uncharger().uncharge(mol)
+         mol = rdMolStandardize.TautomerEnumerator().Canonicalize(mol)
+         Chem.SanitizeMol(mol)
+     except Exception:
+         return None
+     if mol.GetNumHeavyAtoms() < 2:
+         return None
+     standardized = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)
+     if not standardized or "." in standardized:
+         return None
+     return standardized
+
+
+ def standardize_smiles_v2(smiles: str | None) -> str | None:
+     if not smiles:
+         return None
+     token = smiles.strip()
+     if not token:
+         return None
+     return _standardize_smiles_v2_cached(token)
+
+
+ def smiles_sha256(smiles: str) -> str:
+     return hashlib.sha256(smiles.encode("utf-8")).hexdigest()
+
+
+ def _count_atomic_nums(mol) -> dict[int, int]:
+     counts: dict[int, int] = {}
+     for atom in mol.GetAtoms():
+         atomic_num = int(atom.GetAtomicNum())
+         counts[atomic_num] = counts.get(atomic_num, 0) + 1
+     return counts
+
+
+ def _morgan_bits_from_mol(mol, *, radius: int, n_bits: int, use_chirality: bool) -> np.ndarray:
+     fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits, useChirality=use_chirality)
+     arr = np.zeros((n_bits,), dtype=np.uint8)
+     DataStructs.ConvertToNumpyArray(fp, arr)
+     return arr
+
+
+ def _maccs_bits_from_mol(mol) -> np.ndarray:
+     fp = MACCSkeys.GenMACCSKeys(mol)
+     arr = np.zeros((fp.GetNumBits(),), dtype=np.uint8)
+     DataStructs.ConvertToNumpyArray(fp, arr)
+     return arr
+
+
+ def _molecule_descriptor_vector(mol, *, names: tuple[str, ...] = DEFAULT_DESCRIPTOR_NAMES) -> np.ndarray:
+     counts = _count_atomic_nums(mol)
+     fragments = Chem.GetMolFrags(mol)
+     formal_charge = sum(int(atom.GetFormalCharge()) for atom in mol.GetAtoms())
+     max_atomic_num = max(counts) if counts else 0
+     metal_atom_count = sum(count for atomic_num, count in counts.items() if atomic_num not in ORGANIC_LIKE_ATOMIC_NUMBERS)
+     halogen_count = sum(counts.get(item, 0) for item in (9, 17, 35, 53))
+     aromatic_atom_count = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())
+     values = {
+         "mol_wt": float(Descriptors.MolWt(mol)),
+         "logp": float(Crippen.MolLogP(mol)),
+         "tpsa": float(rdMolDescriptors.CalcTPSA(mol)),
+         "heavy_atoms": float(mol.GetNumHeavyAtoms()),
+         "hbd": float(Lipinski.NumHDonors(mol)),
+         "hba": float(Lipinski.NumHAcceptors(mol)),
+         "rot_bonds": float(Lipinski.NumRotatableBonds(mol)),
+         "ring_count": float(rdMolDescriptors.CalcNumRings(mol)),
+         "aromatic_rings": float(rdMolDescriptors.CalcNumAromaticRings(mol)),
+         "aliphatic_rings": float(rdMolDescriptors.CalcNumAliphaticRings(mol)),
+         "saturated_rings": float(rdMolDescriptors.CalcNumSaturatedRings(mol)),
+         "fraction_csp3": float(rdMolDescriptors.CalcFractionCSP3(mol)),
+         "heteroatoms": float(rdMolDescriptors.CalcNumHeteroatoms(mol)),
+         "amide_bonds": float(rdMolDescriptors.CalcNumAmideBonds(mol)),
+         "fragments": float(len(fragments)),
+         "formal_charge": float(formal_charge),
+         "max_atomic_num": float(max_atomic_num),
+         "metal_atom_count": float(metal_atom_count),
+         "halogen_count": float(halogen_count),
+         "nitrogen_count": float(counts.get(7, 0)),
+         "oxygen_count": float(counts.get(8, 0)),
+         "sulfur_count": float(counts.get(16, 0)),
+         "phosphorus_count": float(counts.get(15, 0)),
+         "fluorine_count": float(counts.get(9, 0)),
+         "chlorine_count": float(counts.get(17, 0)),
+         "bromine_count": float(counts.get(35, 0)),
+         "iodine_count": float(counts.get(53, 0)),
+         "aromatic_atom_count": float(aromatic_atom_count),
+         "spiro_atoms": float(rdMolDescriptors.CalcNumSpiroAtoms(mol)),
+         "bridgehead_atoms": float(rdMolDescriptors.CalcNumBridgeheadAtoms(mol)),
+     }
+     return np.asarray([values[name] for name in names], dtype=np.float32)
+
+
+ class CompatibilityHead(nn.Module):
+     def __init__(
+         self,
+         *,
+         assay_dim: int,
+         molecule_dim: int,
+         projection_dim: int,
+         hidden_dim: int,
+         dropout: float,
+         metadata_dim: int = 0,
+     ) -> None:
+         super().__init__()
+         self.metadata_dim = metadata_dim
+         assay_input_dim = assay_dim + metadata_dim
+         self.assay_proj = nn.Sequential(
+             nn.Linear(assay_input_dim, projection_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+         )
+         self.molecule_proj = nn.Sequential(
+             nn.Linear(molecule_dim, projection_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+         )
+         self.scorer = nn.Sequential(
+             nn.Linear(projection_dim * 4, hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim, 1),
+         )
+
+     def forward(self, assay_vec: torch.Tensor, molecule_vec: torch.Tensor, assay_metadata: torch.Tensor | None = None) -> torch.Tensor:
+         if assay_metadata is not None and assay_metadata.numel():
+             assay_input = torch.cat([assay_vec, assay_metadata], dim=-1)
+         else:
+             assay_input = assay_vec
+         assay_hidden = self.assay_proj(assay_input)
+         molecule_hidden = self.molecule_proj(molecule_vec)
+         interaction = torch.cat(
+             [
+                 assay_hidden,
+                 molecule_hidden,
+                 assay_hidden * molecule_hidden,
+                 torch.abs(assay_hidden - molecule_hidden),
+             ],
+             dim=-1,
+         )
+         return self.scorer(interaction).squeeze(-1)
+
+
+ class CompatibilityModel:
+     def __init__(self, assay_encoder: SentenceTransformer, metadata: dict[str, Any], head_state: dict[str, Any], *, device: str | None = None) -> None:
+         self.metadata = metadata
+         self.config = metadata["config"]
+         self.feature_spec = metadata["molecule_feature_spec"]
+         self.metadata_dim = int(self.config.get("assay_metadata_dim", 0))
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+         self.assay_encoder = assay_encoder
+         self.assay_encoder.max_seq_length = 512
+         self.assay_dim = int(self.assay_encoder.get_sentence_embedding_dimension())
+         self.molecule_dim = int(metadata["feature_counts"]["molecule_dim"])
+         self.head = CompatibilityHead(
+             assay_dim=self.assay_dim,
+             molecule_dim=self.molecule_dim,
+             projection_dim=int(self.config["projection_dim"]),
+             hidden_dim=int(self.config["hidden_dim"]),
+             dropout=float(self.config["dropout"]),
+             metadata_dim=self.metadata_dim,
+         ).to(self.device)
+         self.head.load_state_dict(head_state)
+         self.head.eval()
+
+     def encode_assay(self, assay_text: str) -> tuple[torch.Tensor, torch.Tensor | None]:
+         embedding = self.assay_encoder.encode(
+             [assay_text],
+             convert_to_numpy=True,
+             show_progress_bar=False,
+             normalize_embeddings=True,
+             prompt_name="query",
+             prompt=DEFAULT_ASSAY_TASK,
+         )[0].astype(np.float32)
+         assay_vec = torch.from_numpy(embedding).unsqueeze(0).to(self.device)
+         metadata_vec = _assay_metadata_vector(assay_text, dim=self.metadata_dim)
+         metadata_tensor = None
+         if metadata_vec.size:
+             metadata_tensor = torch.from_numpy(metadata_vec).unsqueeze(0).to(self.device)
+         return assay_vec, metadata_tensor
+
+     def score_feature_matrix(self, assay_text: str, feature_matrix: np.ndarray) -> np.ndarray:
+         assay_vec, metadata_tensor = self.encode_assay(assay_text)
+         molecule_tensor = torch.from_numpy(feature_matrix).to(self.device)
+         with torch.inference_mode():
+             assay_repeat = assay_vec.repeat(molecule_tensor.size(0), 1)
+             metadata_repeat = metadata_tensor.repeat(molecule_tensor.size(0), 1) if metadata_tensor is not None else None
+             scores = self.head(assay_repeat, molecule_tensor, metadata_repeat)
+         return scores.detach().cpu().numpy()
+
+
+ def build_molecule_feature_vector(smiles: str, feature_spec: dict[str, Any]) -> np.ndarray | None:
+     standardized = standardize_smiles_v2(smiles)
+     if standardized is None:
+         return None
+     mol = Chem.MolFromSmiles(standardized)
+     if mol is None:
+         return None
+     parts: list[np.ndarray] = []
+     for radius in feature_spec.get("fingerprint_radii", [2, 3]):
+         parts.append(
+             _morgan_bits_from_mol(
+                 mol,
+                 radius=int(radius),
+                 n_bits=int(feature_spec.get("fingerprint_bits", 2048)),
+                 use_chirality=bool(feature_spec.get("use_chirality", True)),
+             ).astype(np.float32)
+         )
+     if feature_spec.get("use_maccs", True):
+         parts.append(_maccs_bits_from_mol(mol).astype(np.float32))
+     if feature_spec.get("use_rdkit_descriptors", True):
+         descriptor_values = _molecule_descriptor_vector(
+             mol,
+             names=tuple(feature_spec.get("descriptor_names", DEFAULT_DESCRIPTOR_NAMES)),
+         )
+         descriptor_mean = np.asarray(feature_spec["descriptor_mean"], dtype=np.float32)
+         descriptor_std = np.asarray(feature_spec["descriptor_std"], dtype=np.float32)
+         parts.append(((descriptor_values - descriptor_mean) / (descriptor_std + 1e-6)).astype(np.float32))
+     if not parts:
+         return None
+     return np.concatenate(parts, axis=0).astype(np.float32)
+
+
+ def load_compatibility_model(model_dir: str | Path, *, device: str | None = None) -> CompatibilityModel:
+     model_path = Path(model_dir)
+     training_metadata = json.loads((model_path / "training_metadata.json").read_text())
+     checkpoint = torch.load(model_path / "best_model.pt", map_location="cpu")
+     assay_model_name = training_metadata["config"]["assay_model_name"]
+     assay_encoder = SentenceTransformer(assay_model_name, device=device or ("cuda" if torch.cuda.is_available() else "cpu"))
+     return CompatibilityModel(assay_encoder, training_metadata, checkpoint["head_state"], device=device)
+
+
+ def load_compatibility_model_from_hub(repo_id: str, *, device: str | None = None) -> CompatibilityModel:
+     snapshot_path = snapshot_download(repo_id=repo_id, repo_type="model", allow_patterns=["best_model.pt", "training_metadata.json"])
+     return load_compatibility_model(snapshot_path, device=device)
+
+
+ def rank_compounds(model: CompatibilityModel, assay_text: str, smiles_list: list[str], *, top_k: int | None = None) -> list[dict[str, Any]]:
+     valid_inputs: list[tuple[str, str, np.ndarray]] = []
+     invalid_rows: list[dict[str, Any]] = []
+     for item in smiles_list:
+         feature_vec = build_molecule_feature_vector(item, model.feature_spec)
+         standardized = standardize_smiles_v2(item)
+         if feature_vec is None or standardized is None:
+             invalid_rows.append({"input_smiles": item, "valid": False, "error": "invalid_smiles"})
+             continue
+         valid_inputs.append((item, standardized, feature_vec))
+     valid_rows: list[dict[str, Any]] = []
+     if valid_inputs:
+         feature_matrix = np.stack([entry[2] for entry in valid_inputs], axis=0).astype(np.float32)
+         scores = model.score_feature_matrix(assay_text, feature_matrix)
+         for (input_smiles, standardized, _), score in zip(valid_inputs, scores):
+             valid_rows.append(
+                 {
+                     "input_smiles": input_smiles,
+                     "canonical_smiles": standardized,
+                     "smiles_hash": smiles_sha256(standardized),
+                     "score": float(score),
+                     "valid": True,
+                 }
+             )
+     valid_rows.sort(key=lambda item: item["score"], reverse=True)
+     if top_k:
+         valid_rows = valid_rows[:top_k]
+     return valid_rows + invalid_rows
+
+
+ def list_softmax_scores(scores: list[float], temperature: float = 1.0) -> list[float]:
+     values = np.asarray(scores, dtype=np.float32) / max(float(temperature), 1e-6)
+     values = values - values.max()
+     probs = np.exp(values)
+     probs = probs / probs.sum()
+     return probs.tolist()
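Review note on `_hash_bucket`: Python's built-in `hash()` for strings is randomized per process unless `PYTHONHASHSEED` is set, so the metadata token buckets may not be reproducible between runs. A minimal sketch of a process-independent alternative is below. The name `stable_hash_bucket` and the choice of SHA-256 are assumptions for illustration only. The commit does not say how buckets were assigned at training time, so this sketch would not necessarily reproduce the buckets the published head was trained with.

```python
import hashlib


def stable_hash_bucket(value: str, dim: int) -> int:
    """Map a token to a bucket index that is identical in every Python process.

    Hypothetical replacement for illustration; not part of the published helper.
    """
    digest = hashlib.sha256(value.encode("utf-8")).digest()
    # Read the first 8 bytes as an unsigned integer, then fold into [0, dim).
    return int.from_bytes(digest[:8], "big") % max(dim, 1)


bucket = stable_hash_bucket("target:O60674", 64)
print(bucket)
```

Until the training-time hashing is confirmed, a lower-risk option is to keep the existing function and run inference with the same `PYTHONHASHSEED` that was used during training, if one was fixed.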
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch>=2.2
+ transformers>=4.50
+ sentence-transformers>=3.0
+ huggingface_hub>=0.30
+ numpy>=1.26
+ rdkit>=2023.9
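Review note: after `python -m pip install -r requirements.txt`, a quick way to see which of the six pinned distributions are present is to query the installed metadata. This is a minimal sketch using only the standard library. The function name `installed_versions` is hypothetical, and the sketch only reports versions; it does not compare them against the `>=` bounds above.

```python
from __future__ import annotations

from importlib import metadata

# Distribution names copied from requirements.txt in this commit.
REQUIRED = [
    "torch",
    "transformers",
    "sentence-transformers",
    "huggingface_hub",
    "numpy",
    "rdkit",
]


def installed_versions(packages: list[str]) -> dict[str, str | None]:
    """Return the installed version per distribution, or None when it is missing."""
    found: dict[str, str | None] = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


for name, version in installed_versions(REQUIRED).items():
    print(f"{name}: {version or 'not installed'}")
```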