import csv
import math
import os
import pickle
from typing import Literal

import torch

from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor, ChEBIOver100
1111
1212class ImplicationLoss (torch .nn .Module ):
1313 def __init__ (
14- self , path_to_chebi , path_to_label_names , base_loss : torch .nn .Module = None
14+ self ,
15+ data_extractor : _ChEBIDataExtractor ,
16+ base_loss : torch .nn .Module = None ,
17+ tnorm : Literal ["product" , "lukasiewicz" , "xu19" ] = "product" ,
18+ impl_loss_weight = 0.1 , # weight of implication loss in relation to base_loss
19+ pos_scalar = 1 ,
20+ pos_epsilon = 0.01 ,
1521 ):
1622 super ().__init__ ()
23+ self .data_extractor = data_extractor
1724 self .base_loss = base_loss
18- label_names = _load_label_names (path_to_label_names )
19- hierarchy = _load_implications (path_to_chebi )
20- implication_filter = _build_implication_filter (label_names , hierarchy )
25+ self .implication_cache_file = f"implications_{ self .data_extractor .name } .cache"
26+ self .label_names = _load_label_names (
27+ os .path .join (data_extractor .raw_dir , "classes.txt" )
28+ )
29+ self .hierarchy = self ._load_implications (
30+ os .path .join (data_extractor .raw_dir , "chebi.obo" )
31+ )
32+ implication_filter = _build_implication_filter (self .label_names , self .hierarchy )
2133 self .implication_filter_l = implication_filter [:, 0 ]
2234 self .implication_filter_r = implication_filter [:, 1 ]
35+ self .tnorm = tnorm
36+ self .impl_weight = impl_loss_weight
37+ self .pos_scalar = pos_scalar
38+ self .eps = pos_epsilon
2339
2440 def forward (self , input , target , ** kwargs ):
2541 nnl = kwargs .pop ("non_null_labels" , None )
@@ -36,40 +52,77 @@ def forward(self, input, target, **kwargs):
3652 r = pred [:, self .implication_filter_r ]
3753 # implication_loss = torch.sqrt(torch.mean(torch.sum(l*(1-r), dim=-1), dim=0))
3854 implication_loss = self ._calculate_implication_loss (l , r )
39- return base_loss + implication_loss
55+
56+ return (
57+ base_loss + self .impl_weight * implication_loss ,
58+ base_loss ,
59+ implication_loss ,
60+ )
4061
4162 def _calculate_implication_loss (self , l , r ):
42- capped_difference = torch .relu (l - r )
63+ assert not l .isnan ().any ()
64+ assert not r .isnan ().any ()
65+ if self .pos_scalar != 1 :
66+ l = (
67+ torch .pow (l + self .eps , 1 / self .pos_scalar )
68+ - math .pow (self .eps , 1 / self .pos_scalar )
69+ ) / (
70+ math .pow (1 + self .eps , 1 / self .pos_scalar )
71+ - math .pow (self .eps , 1 / self .pos_scalar )
72+ )
73+ r = torch .pow (r , self .pos_scalar )
74+ if self .tnorm == "product" :
75+ individual_loss = l * (1 - r )
76+ elif self .tnorm == "xu19" :
77+ individual_loss = - torch .log (1 - l * (1 - r ))
78+ elif self .tnorm == "lukasiewicz" :
79+ individual_loss = torch .relu (l - r )
80+ else :
81+ raise NotImplementedError (f"Unknown tnorm { self .tnorm } " )
82+
4383 return torch .mean (
44- torch .sum (
45- (torch .softmax (capped_difference , dim = - 1 ) * capped_difference ), dim = - 1
46- ),
84+ torch .sum (individual_loss , dim = - 1 ),
4785 dim = 0 ,
4886 )
4987
88+ def _load_implications (self , path_to_chebi ):
89+ if os .path .isfile (self .implication_cache_file ):
90+ with open (self .implication_cache_file , "rb" ) as fin :
91+ hierarchy = pickle .load (fin )
92+ else :
93+ hierarchy = self .data_extractor .extract_class_hierarchy (path_to_chebi )
94+ with open (self .implication_cache_file , "wb" ) as fout :
95+ pickle .dump (hierarchy , fout )
96+ return hierarchy
97+
5098
class DisjointLoss(ImplicationLoss):
    """ImplicationLoss extended with a penalty for violating disjointness axioms.

    ``disjoint(l, r)`` is scored as ``implies(l, not r)``, i.e. the implication
    loss applied to ``(l, 1 - r)``.
    """

    def __init__(
        self,
        path_to_disjointness,
        data_extractor: _ChEBIDataExtractor,
        base_loss: torch.nn.Module = None,
        disjoint_loss_weight: float = 100,
        **kwargs,
    ):
        """
        :param path_to_disjointness: CSV file of disjoint ChEBI-id pairs.
        :param data_extractor: forwarded to ``ImplicationLoss``.
        :param base_loss: forwarded to ``ImplicationLoss``.
        :param disjoint_loss_weight: multiplier for the disjointness term.
        :param kwargs: further ``ImplicationLoss`` options (tnorm, weights, ...).
        """
        super().__init__(data_extractor, base_loss, **kwargs)
        self.disjoint_filter_l, self.disjoint_filter_r = _build_disjointness_filter(
            path_to_disjointness, self.label_names, self.hierarchy
        )
        self.disjoint_weight = disjoint_loss_weight

    def forward(self, input, target, **kwargs):
        """Return (total, base_loss, implication_loss, disjointness_loss)."""
        loss, base_loss, impl_loss = super().forward(input, target, **kwargs)
        pred = torch.sigmoid(input)
        l = pred[:, self.disjoint_filter_l]
        r = pred[:, self.disjoint_filter_r]
        # disjointness of l and r == "l implies not r"
        disjointness_loss = self._calculate_implication_loss(l, 1 - r)
        return (
            loss + self.disjoint_weight * disjointness_loss,
            base_loss,
            impl_loss,
            disjointness_loss,
        )
73126
74127
75128def _load_label_names (path_to_label_names ):
@@ -78,17 +131,6 @@ def _load_label_names(path_to_label_names):
78131 return label_names
79132
80133
81- def _load_implications (path_to_chebi , implication_cache = IMPLICATION_CACHE_FILE ):
82- if os .path .isfile (implication_cache ):
83- with open (implication_cache , "rb" ) as fin :
84- hierarchy = pickle .load (fin )
85- else :
86- hierarchy = extract_class_hierarchy (path_to_chebi )
87- with open (implication_cache , "wb" ) as fout :
88- pickle .dump (hierarchy , fout )
89- return hierarchy
90-
91-
92134def _build_implication_filter (label_names , hierarchy ):
93135 return torch .tensor (
94136 [
@@ -100,24 +142,33 @@ def _build_implication_filter(label_names, hierarchy):
100142 )
101143
102144
103- def _build_disjointness_filter (path_to_disjointedness , label_names , hierarchy ):
145+ def _build_disjointness_filter (path_to_disjointness , label_names , hierarchy ):
104146 disjoints = set ()
105147 label_dict = dict (map (reversed , enumerate (label_names )))
106148
107- with open (path_to_disjointedness , "rt" ) as fin :
149+ with open (path_to_disjointness , "rt" ) as fin :
108150 reader = csv .reader (fin )
109151 for l1_raw , r1_raw in reader :
110152 l1 = int (l1_raw )
111153 r1 = int (r1_raw )
154+ if l1 == 36233 and r1 == 63353 :
155+ # ignore disaccharide-disaccharide derivative disjointness axiom
156+ continue
112157 disjoints .update (
113158 {
114159 (label_dict [l2 ], label_dict [r2 ])
115- for r2 in hierarchy .succ [r1 ]
160+ for r2 in list ( hierarchy .succ [ r1 ]) + [r1 ]
116161 if r2 in label_names
117- for l2 in hierarchy .succ [l1 ]
118- if l2 in label_names and l2 < r2
162+ for l2 in list ( hierarchy .succ [ l1 ]) + [l1 ]
163+ if l2 in label_names
119164 }
120165 )
121166
122167 dis_filter = torch .tensor (list (disjoints ))
123168 return dis_filter [:, 0 ], dis_filter [:, 1 ]
169+
170+
if __name__ == "__main__":
    # Smoke test: build the combined loss from a local disjointness CSV and
    # the ChEBI-over-100 dataset (version 227); exercises the full filter
    # construction, including the hierarchy cache.
    loss = DisjointLoss(
        os.path.join("data", "disjoint.csv"), ChEBIOver100(chebi_version=227)
    )