11import os
22import sys
3+ from abc import ABC
34from itertools import islice
4- from typing import Any , Dict , List , Optional , Tuple
5+ from typing import Any , Dict , List , Optional
56
67import deepsmiles
78import selfies as sf
@@ -119,23 +120,14 @@ def on_finish(self) -> None:
119120 return
120121
121122
122- class ChemDataReader (DataReader ):
123+ class TokenIndexerReader (DataReader , ABC ):
123124 """
124- Data reader for chemical data using SMILES tokens.
    Abstract base class for reading tokenized data and mapping tokens to unique indices.
125126
126- Args:
127- collator_kwargs: Optional dictionary of keyword arguments for the collator.
128- token_path: Optional path for the token file.
129- kwargs: Additional keyword arguments.
127+ This class maintains a cache of token-to-index mappings that can be extended during runtime,
128+ and saves new tokens to a persistent file at the end of processing.
130129 """
131130
132- COLLATOR = RaggedCollator
133-
134- @classmethod
135- def name (cls ) -> str :
136- """Returns the name of the data reader."""
137- return "smiles_token"
138-
139131 def __init__ (self , * args , ** kwargs ):
140132 super ().__init__ (* args , ** kwargs )
141133 with open (self .token_path , "r" ) as pk :
@@ -150,21 +142,9 @@ def _get_token_index(self, token: str) -> int:
150142 self .cache [(str (token ))] = len (self .cache )
151143 return self .cache [str (token )] + EMBEDDING_OFFSET
152144
153- def _read_data (self , raw_data : str ) -> List [int ]:
154- """
155- Reads and tokenizes raw SMILES data into a list of token indices.
156-
157- Args:
158- raw_data (str): The raw SMILES string to be tokenized.
159-
160- Returns:
161- List[int]: A list of integers representing the indices of the SMILES tokens.
162- """
163- return [self ._get_token_index (v [1 ]) for v in _tokenize (raw_data )]
164-
165145 def on_finish (self ) -> None :
166146 """
167- Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
        Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
168148 """
169149 print (f"first 10 tokens: { list (islice (self .cache , 10 ))} " )
170150
@@ -188,6 +168,31 @@ def on_finish(self) -> None:
188168 pk .writelines ([f"{ c } \n " for c in new_tokens ])
189169
190170
class ChemDataReader(TokenIndexerReader):
    """Data reader that turns raw SMILES strings into sequences of token indices."""

    # Batch items are variable-length token lists, so use the ragged collator.
    COLLATOR = RaggedCollator

    @classmethod
    def name(cls) -> str:
        """Return the registry name of this data reader."""
        return "smiles_token"

    def _read_data(self, raw_data: str) -> List[int]:
        """
        Tokenize a raw SMILES string and map each token to its cached index.

        Args:
            raw_data (str): The raw SMILES string to be tokenized.

        Returns:
            List[int]: Index (plus embedding offset) of every SMILES token, in order.
        """
        indices = []
        for token_pair in _tokenize(raw_data):
            # _tokenize yields pairs; position 1 holds the token text.
            indices.append(self._get_token_index(token_pair[1]))
        return indices
191196class DeepChemDataReader (ChemDataReader ):
192197 """
193198 Data reader for chemical data using DeepSMILES tokens.
0 commit comments