 import argparse
+import json
+import logging
 from pathlib import Path
 from typing import Iterable
 
 import numpy as np
 from datasets import load_dataset
 from more_itertools import batched
-from reach import Reach
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
 
 _SAVE_INTERVAL = 10
 _MAX_MEANS = 1100000
 
+logger = logging.getLogger(__name__)
+
+
+def save_data(means: list[np.ndarray], txts: list[str], base_filepath: str) -> None:
+    """
+    Save the means and texts to separate files.
+
+    :param means: List of numpy arrays representing the mean embeddings.
+    :param txts: List of texts corresponding to the embeddings.
+    :param base_filepath: Base path for the output files.
+    """
+    vectors_filepath = base_filepath + "_vectors.npy"
+    items_filepath = base_filepath + "_items.json"
+
+    # Save the embeddings (vectors) to a .npy file
+    np.save(vectors_filepath, np.array(means))
+    # Save the texts to a JSON file
+    with open(items_filepath, "w") as f:
+        json.dump({"items": txts}, f)
+    logger.info(f"Saved {len(txts)} texts to {items_filepath} and vectors to {vectors_filepath}")
+
 
 def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str) -> None:
     """
@@ -35,55 +57,76 @@ def featurize(texts: Iterable[str], model: SentenceTransformer, output_dir: str) -> None:
 
     for index, batch in enumerate(tqdm(batched(texts, 32))):
         i = index // _SAVE_INTERVAL
-        if (out_path / f"featurized_{i}.json").exists():
-            continue
-        # Consume the generator
+        base_filename = f"featurized_{i}"
+        vectors_filepath = out_path / (base_filename + "_vectors.npy")
+        items_filepath = out_path / (base_filename + "_items.json")
         list_batch = [x["text"].strip() for x in batch if x.get("text")]
+        if not list_batch:
+            continue  # Skip empty batches
+
+        # Encode the batch to get token embeddings
+        token_embeddings = model.encode(
+            list_batch,
+            output_value="token_embeddings",
+            convert_to_tensor=True,
+        )
 
-        # Already truncated to model max_length
+        # Tokenize the batch to get input IDs
        tokenized_ids = model.tokenize(list_batch)["input_ids"]
-        token_embeddings: list[np.ndarray] = [
-            x.cpu().numpy() for x in model.encode(list_batch, output_value="token_embeddings", convert_to_numpy=True)
-        ]
 
-        for tokenized_id, token_embedding in zip(tokenized_ids, token_embeddings, strict=True):
-            # Truncate to actual length of vectors, remove CLS and SEP.
-            text = model.tokenizer.decode(tokenized_id[1 : len(token_embedding) - 1])
+        for tokenized_id, token_embedding in zip(tokenized_ids, token_embeddings):
+            # Drop the special tokens (CLS and SEP) from the token IDs
+            token_ids = tokenized_id[1:-1]
+            # Decode the remaining token IDs back to text
+            text = model.tokenizer.decode(token_ids)
             if text in seen:
                 continue
             seen.add(text)
+            # Get the corresponding token embeddings (excluding special tokens)
+            token_embeds = token_embedding[1:-1]
+            # Detach, move to the CPU, and convert to a NumPy array
+            token_embeds = token_embeds.detach().cpu().numpy()
+            # Compute the mean of the token embeddings
+            mean = np.mean(token_embeds, axis=0)
             txts.append(text)
             means.append(mean)
             total_means += 1
 
             if total_means >= _MAX_MEANS:
-                # Save the final batch and stop
-                r = Reach(means, txts)
-                r.save(out_path / f"featurized_{(index // _SAVE_INTERVAL)}.json")
+                save_data(means, txts, str(out_path / base_filename))
                 return
 
         if index > 0 and (index + 1) % _SAVE_INTERVAL == 0:
-            r = Reach(means, txts)
-            r.save(out_path / f"featurized_{(index // _SAVE_INTERVAL)}.json")
+            save_data(means, txts, str(out_path / base_filename))
             txts = []
             means = []
             seen = set()
     else:
-        if means:
-            r = Reach(means, txts)
-            r.save(out_path / f"featurized_{(index // _SAVE_INTERVAL)}.json")
+        if txts and means:
+            save_data(means, txts, str(out_path / base_filename))
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Train a Model2Vec using tokenlearn.")
+def main() -> None:
+    """Main function to featurize texts using a sentence transformer."""
+    parser = argparse.ArgumentParser(description="Featurize texts using a sentence transformer.")
     parser.add_argument(
         "--model-name",
         type=str,
         default="baai/bge-base-en-v1.5",
         help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').",
     )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="data/c4_bgebase",
+        help="Directory to save the featurized texts.",
+    )
     args = parser.parse_args()
+
     model = SentenceTransformer(args.model_name)
     dataset = load_dataset("allenai/c4", name="en", split="train", streaming=True)
-    featurize(dataset, model, "data/c4_bgebase")
+    featurize(dataset, model, args.output_dir)
+
+
+if __name__ == "__main__":
+    main()