@@ -135,6 +135,8 @@ def get_chebi_graph(data_module, label_names):
135135 chebi_graph = data_module ._extract_class_hierarchy (
136136 os .path .join (data_module .raw_dir , "chebi.obo" )
137137 )
138+ if label_names is None :
139+ return chebi_graph
138140 return chebi_graph .subgraph ([int (n ) for n in label_names ])
139141 print (
140142 f"Failed to retrieve ChEBI graph, { os .path .join (data_module .raw_dir , 'chebi.obo' )} not found"
@@ -190,32 +192,38 @@ class PredictionSmoother:
190192 """Removes implication and disjointness violations from predictions"""
191193
192194 def __init__ (self , dataset , label_names = None , disjoint_files = None ):
193- if label_names :
194- self .label_names = label_names
195- else :
196- self .label_names = get_label_names (dataset )
197- self .chebi_graph = get_chebi_graph (dataset , self .label_names )
195+ self .chebi_graph = get_chebi_graph (dataset , None )
196+ self .set_label_names (label_names )
198197 self .disjoint_groups = get_disjoint_groups (disjoint_files )
199198
199+ def set_label_names (self , label_names ):
200+ if label_names is not None :
201+ self .label_names = [int (label ) for label in label_names ]
202+ chebi_subgraph = self .chebi_graph .subgraph (self .label_names )
203+ self .label_successors = torch .zeros (
204+ (len (self .label_names ), len (self .label_names )), dtype = torch .bool
205+ )
206+ for i , label in enumerate (self .label_names ):
207+ self .label_successors [i , i ] = 1
208+ for p in chebi_subgraph .successors (label ):
209+ if p in self .label_names :
210+ self .label_successors [i , self .label_names .index (p )] = 1
211+ self .label_successors = self .label_successors .unsqueeze (0 )
212+
200213 def __call__ (self , preds ):
201214 preds_sum_orig = torch .sum (preds )
202- for i , label in enumerate (self .label_names ):
203- succs = [
204- self .label_names .index (str (p ))
205- for p in self .chebi_graph .successors (int (label ))
206- ] + [i ]
207- if len (succs ) > 0 :
208- preds [:, i ] = torch .max (preds [:, succs ], dim = 1 ).values
215+ # step 1: apply implications: for each class, set prediction to max of itself and all successors
216+ preds = preds .unsqueeze (1 )
217+ preds_masked_succ = torch .where (self .label_successors , preds , 0 )
218+ preds = preds_masked_succ .max (dim = 2 ).values
209219 if torch .sum (preds ) != preds_sum_orig :
210220 print (f"Preds change (step 1): { torch .sum (preds ) - preds_sum_orig } " )
211221 preds_sum_orig = torch .sum (preds )
212222 # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
213223 preds_bounded = torch .min (preds , torch .ones_like (preds ) * 0.49 )
214224 for disj_group in self .disjoint_groups :
215225 disj_group = [
216- self .label_names .index (str (g ))
217- for g in disj_group
218- if g in self .label_names
226+ self .label_names .index (g ) for g in disj_group if g in self .label_names
219227 ]
220228 if len (disj_group ) > 1 :
221229 old_preds = preds [:, disj_group ]
@@ -232,26 +240,17 @@ def __call__(self, preds):
232240 print (
233241 f"disjointness group { [self .label_names [d ] for d in disj_group ]} changed { samples_changed } samples"
234242 )
243+ if torch .sum (preds ) != preds_sum_orig :
244+ print (f"Preds change (step 2): { torch .sum (preds ) - preds_sum_orig } " )
235245 preds_sum_orig = torch .sum (preds )
236246 # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
237- for i , label in enumerate (self .label_names ):
238- predecessors = [i ] + [
239- self .label_names .index (str (p ))
240- for p in self .chebi_graph .predecessors (int (label ))
241- ]
242- lowest_predecessors = torch .min (preds [:, predecessors ], dim = 1 )
243- preds [:, i ] = lowest_predecessors .values
244- for idx_idx , idx in enumerate (lowest_predecessors .indices ):
245- if idx > 0 :
246- print (
247- f"class { label } : changed prediction of sample { idx_idx } to value of class "
248- f"{ self .label_names [predecessors [idx ]]} ({ preds [idx_idx , i ].item ():.2f} )"
249- )
250- if torch .sum (preds ) != preds_sum_orig :
251- print (
252- f"Preds change (step 3) for { label } : { torch .sum (preds ) - preds_sum_orig } "
253- )
254- preds_sum_orig = torch .sum (preds )
247+ preds = preds .unsqueeze (1 )
248+ preds_masked_predec = torch .where (
249+ torch .transpose (self .label_successors , 1 , 2 ), preds , 1
250+ )
251+ preds = preds_masked_predec .min (dim = 2 ).values
252+ if torch .sum (preds ) != preds_sum_orig :
253+ print (f"Preds change (step 3): { torch .sum (preds ) - preds_sum_orig } " )
255254 return preds
256255
257256
0 commit comments