Skip to content

Commit 94f6710

Browse files
authored
Merge pull request #88 from ChEB-AI/index-reader-abstract-class
Token Indexer Reader abstract class
2 parents b39d8de + 42ed09d commit 94f6710

File tree

1 file changed

+32
-27
lines changed

1 file changed

+32
-27
lines changed

chebai/preprocessing/reader.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import os
22
import sys
3+
from abc import ABC
34
from itertools import islice
4-
from typing import Any, Dict, List, Optional, Tuple
5+
from typing import Any, Dict, List, Optional
56

67
import deepsmiles
78
import selfies as sf
@@ -119,23 +120,14 @@ def on_finish(self) -> None:
119120
return
120121

121122

122-
class ChemDataReader(DataReader):
123+
class TokenIndexerReader(DataReader, ABC):
123124
"""
124-
Data reader for chemical data using SMILES tokens.
125+
Abstract base class for reading tokenized data and mapping tokens to unique indices.
125126
126-
Args:
127-
collator_kwargs: Optional dictionary of keyword arguments for the collator.
128-
token_path: Optional path for the token file.
129-
kwargs: Additional keyword arguments.
127+
This class maintains a cache of token-to-index mappings that can be extended during runtime,
128+
and saves new tokens to a persistent file at the end of processing.
130129
"""
131130

132-
COLLATOR = RaggedCollator
133-
134-
@classmethod
135-
def name(cls) -> str:
136-
"""Returns the name of the data reader."""
137-
return "smiles_token"
138-
139131
def __init__(self, *args, **kwargs):
140132
super().__init__(*args, **kwargs)
141133
with open(self.token_path, "r") as pk:
@@ -150,21 +142,9 @@ def _get_token_index(self, token: str) -> int:
150142
self.cache[(str(token))] = len(self.cache)
151143
return self.cache[str(token)] + EMBEDDING_OFFSET
152144

153-
def _read_data(self, raw_data: str) -> List[int]:
154-
"""
155-
Reads and tokenizes raw SMILES data into a list of token indices.
156-
157-
Args:
158-
raw_data (str): The raw SMILES string to be tokenized.
159-
160-
Returns:
161-
List[int]: A list of integers representing the indices of the SMILES tokens.
162-
"""
163-
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
164-
165145
def on_finish(self) -> None:
166146
"""
167-
Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
147+
Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
168148
"""
169149
print(f"first 10 tokens: {list(islice(self.cache, 10))}")
170150

@@ -188,6 +168,31 @@ def on_finish(self) -> None:
188168
pk.writelines([f"{c}\n" for c in new_tokens])
189169

190170

171+
class ChemDataReader(TokenIndexerReader):
172+
"""
173+
Data reader for chemical data using SMILES tokens.
174+
"""
175+
176+
COLLATOR = RaggedCollator
177+
178+
@classmethod
179+
def name(cls) -> str:
180+
"""Returns the name of the data reader."""
181+
return "smiles_token"
182+
183+
def _read_data(self, raw_data: str) -> List[int]:
184+
"""
185+
Reads and tokenizes raw SMILES data into a list of token indices.
186+
187+
Args:
188+
raw_data (str): The raw SMILES string to be tokenized.
189+
190+
Returns:
191+
List[int]: A list of integers representing the indices of the SMILES tokens.
192+
"""
193+
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
194+
195+
191196
class DeepChemDataReader(ChemDataReader):
192197
"""
193198
Data reader for chemical data using DeepSMILES tokens.

0 commit comments

Comments (0)