diff --git a/chebifier/prediction_models/c3p_predictor.py b/chebifier/prediction_models/c3p_predictor.py index 4ef58a3..bf6c39b 100644 --- a/chebifier/prediction_models/c3p_predictor.py +++ b/chebifier/prediction_models/c3p_predictor.py @@ -1,6 +1,8 @@ from pathlib import Path from typing import List, Optional +import tqdm + from chebifier import modelwise_smiles_lru_cache from chebifier.prediction_models import BasePredictor @@ -26,14 +28,22 @@ def __init__( def predict_smiles_list(self, smiles_list: list[str]) -> list: from c3p import classifier as c3p_classifier - result_list = c3p_classifier.classify( - list(smiles_list), - self.program_directory, - self.chemical_classes, - strict=False, - ) + result_list = [] + for batch_start in tqdm.tqdm( + range(0, len(smiles_list), 32), desc="Classifying with C3P" + ): + batch_end = min(batch_start + 32, len(smiles_list)) + result_list.extend( + c3p_classifier.classify( + smiles_list[batch_start:batch_end], + self.program_directory, + self.chemical_classes, + strict=False, + ) + ) + result_reformatted = [dict() for _ in range(len(smiles_list))] - for result in result_list: + for result in tqdm.tqdm(result_list, desc="Reformatting C3P results"): chebi_id = result.class_id.split(":")[1] result_reformatted[smiles_list.index(result.input_smiles)][ chebi_id @@ -61,13 +71,13 @@ def explain_smiles(self, smiles): highlights.append( ( "text", - f"For class {result.class_name} ({result.class_id}), C3P gave the following explanation: {result.reason}", + f"For {result.class_name} ({result.class_id}), C3P gave the following explanation: {result.reason}", ) ) highlights = [ ( "text", - f"C3P made positive predictions for {len(highlights)} classes. The explanations are as follows:", + f"C3P made positive predictions for {len(highlights)} classes. {'The explanations are as follows:' if len(highlights) > 0 else ''}", ) ] + highlights diff --git a/chebifier/prediction_models/chebi_lookup.py b/chebifier/prediction_models/chebi_lookup.py index c8ad75a..d68c9cd 100644 --- a/chebifier/prediction_models/chebi_lookup.py +++ b/chebifier/prediction_models/chebi_lookup.py @@ -6,7 +6,7 @@ from chebifier import modelwise_smiles_lru_cache from chebifier.prediction_models import BasePredictor -from chebifier.utils import load_chebi_graph +from chebifier.utils import _smiles_to_mol, load_chebi_graph class ChEBILookupPredictor(BasePredictor): @@ -50,7 +50,7 @@ def build_smiles_lookup(self): ).items(): if smiles is not None: try: - mol = Chem.MolFromSmiles(smiles) + mol = _smiles_to_mol(smiles) if mol is None: print( f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}" @@ -72,7 +72,7 @@ def build_smiles_lookup(self): def predict_smiles(self, smiles: str) -> Optional[dict]: if not smiles: return None - mol = Chem.MolFromSmiles(smiles) + mol = _smiles_to_mol(smiles) if mol is None: return None canonical_smiles = Chem.MolToSmiles(mol) @@ -110,7 +110,7 @@ def info_text(self): return self._description def explain_smiles(self, smiles: str) -> dict: - mol = Chem.MolFromSmiles(smiles) + mol = _smiles_to_mol(smiles) if mol is None: return { "highlights": [ diff --git a/chebifier/utils.py b/chebifier/utils.py index fbdcde5..96f4a80 100644 --- a/chebifier/utils.py +++ b/chebifier/utils.py @@ -1,3 +1,4 @@ +import functools import importlib.resources import os import pickle @@ -6,6 +7,7 @@ import networkx as nx import requests import yaml +from rdkit import Chem from chebifier.hugging_face import download_model_files @@ -156,6 +158,18 @@ def process_config(config, model_registry): return new_config +@functools.lru_cache(maxsize=128) +def _smiles_to_mol(smiles: str): + mol = Chem.MolFromSmiles(smiles, sanitize=False) + if mol is not None: + # turn aromatic bond types into single/double + try: + Chem.Kekulize(mol) + except Chem.KekulizeException as e: + print(f"Failed to Kekulize {smiles}: {e}") + return mol + + if __name__ == "__main__": chebi_graph = build_chebi_graph(chebi_version=244) os.makedirs(os.path.join("data", "chebi_v244"), exist_ok=True)