@@ -233,40 +233,54 @@ def report(
233233 hol_sample_size or float ("inf" ),
234234 )
235235
236- if max_sample_size_embeddings_final >= 10_000 and max_sample_size_embeddings is None :
236+ if max_sample_size_embeddings_final > 10_000 and max_sample_size_embeddings is None :
237237 warnings .warn (
238238 UserWarning (
239239 "More than 10k embeddings will be calculated per dataset. "
240240 "Consider setting a limit via `max_sample_size_embeddings`."
241241 )
242242 )
243243
244- def _calc_pull_embeds (
245- df_tgt : pd .DataFrame , df_ctx : pd .DataFrame , progress_from : int , progress_to : int
246- ) -> np .ndarray :
247- strings = pull_data_for_embeddings (
248- df_tgt = df_tgt ,
249- df_ctx = df_ctx ,
244+ _LOG .info ("calculate embeddings for synthetic" )
245+ syn_embeds = calculate_embeddings (
246+ strings = pull_data_for_embeddings (
247+ df_tgt = syn_tgt_data ,
248+ df_ctx = syn_ctx_data ,
250249 ctx_primary_key = ctx_primary_key ,
251250 tgt_context_key = tgt_context_key ,
252251 max_sample_size = max_sample_size_embeddings_final ,
253- )
254- # split into buckets for calculating embeddings to avoid memory issues and report continuous progress
255- buckets = np .array_split (strings , progress_to - progress_from )
256- buckets = [b for b in buckets if len (b ) > 0 ]
257- embeds = []
258- for i , bucket in enumerate (buckets , 1 ):
259- embeds += [calculate_embeddings (bucket .tolist ())]
260- progress .update (completed = progress_from + i , total = 100 )
261- progress .update (completed = progress_to , total = 100 )
262- embeds = np .concatenate (embeds , axis = 0 )
263- _LOG .info (f"calculated embeddings { embeds .shape } " )
264- return embeds
265-
266- syn_embeds = _calc_pull_embeds (df_tgt = syn_tgt_data , df_ctx = syn_ctx_data , progress_from = 20 , progress_to = 40 )
267- trn_embeds = _calc_pull_embeds (df_tgt = trn_tgt_data , df_ctx = trn_ctx_data , progress_from = 40 , progress_to = 60 )
252+ ),
253+ progress = progress ,
254+ progress_from = 20 ,
255+ progress_to = 40 ,
256+ )
257+ _LOG .info ("calculate embeddings for training" )
258+ trn_embeds = calculate_embeddings (
259+ strings = pull_data_for_embeddings (
260+ df_tgt = trn_tgt_data ,
261+ df_ctx = trn_ctx_data ,
262+ ctx_primary_key = ctx_primary_key ,
263+ tgt_context_key = tgt_context_key ,
264+ max_sample_size = max_sample_size_embeddings_final ,
265+ ),
266+ progress = progress ,
267+ progress_from = 40 ,
268+ progress_to = 60 ,
269+ )
268270 if hol_tgt_data is not None :
269- hol_embeds = _calc_pull_embeds (df_tgt = hol_tgt_data , df_ctx = hol_ctx_data , progress_from = 60 , progress_to = 80 )
271+ _LOG .info ("calculate embeddings for holdout" )
272+ hol_embeds = calculate_embeddings (
273+ strings = pull_data_for_embeddings (
274+ df_tgt = hol_tgt_data ,
275+ df_ctx = hol_ctx_data ,
276+ ctx_primary_key = ctx_primary_key ,
277+ tgt_context_key = tgt_context_key ,
278+ max_sample_size = max_sample_size_embeddings_final ,
279+ ),
280+ progress = progress ,
281+ progress_from = 60 ,
282+ progress_to = 80 ,
283+ )
270284 else :
271285 hol_embeds = None
272286 progress .update (completed = 80 , total = 100 )
0 commit comments