@@ -328,6 +328,14 @@ def forward(self, data, **kwargs):
328328 target_mask = data .get ("target_mask" ),
329329 )
330330
331+ def softabs (x , eps = 0.01 ):
332+ return (x ** 2 + eps )** 0.5 - eps ** 0.5
333+
334+ def in_cone_parts (vectors , cone_axes , cone_arcs ):
335+ theta_L = cone_axes + cone_arcs
336+ theta_R = cone_axes - cone_arcs
337+ return ((softabs (vectors - theta_L ) + softabs (vectors - theta_R )) - cone_arcs )/ (2 * pi - cone_arcs )
338+
331339class ConeLoss :
332340
333341 def __init__ (self , center_scaling = 0.1 ):
@@ -338,40 +346,10 @@ def negate(self, ax, arc):
338346 offset [ax >= 0 ] *= - 1
339347 return ax + offset , pi - arc
340348
341- @classmethod
342- def cal_logit_cone (cls , entity_embedding , query_axis_embedding , query_arg_embedding , center_scaling = 0.2 ):
343- """Cone distance from https://github.com/MIRALab-USTC/QE-ConE
344- :param entity_embedding:
345- :param query_axis_embedding:
346- :param query_arg_embedding:
347- :return:
348- """
349-
350- e = entity_embedding .unsqueeze (1 )
351-
352- distance2axis = torch .abs (torch .sin ((e - query_axis_embedding ) / 2 ))
353- distance_base = torch .abs (torch .sin (query_arg_embedding / 2 ))
354-
355- indicator_in = distance2axis < distance_base
356- distance_out = torch .min (torch .abs (torch .sin (e - (query_axis_embedding - query_arg_embedding ) / 2 )), torch .abs (torch .sin (e - (query_axis_embedding + query_arg_embedding ) / 2 )))
357- distance_out [indicator_in ] = 0.
358-
359- distance_in = torch .min (distance2axis , distance_base )
360-
361- distance = torch .norm (distance_out , p = 1 , dim = - 1 )/ e .shape [- 1 ] + center_scaling * torch .norm (distance_in , p = 1 , dim = - 1 )/ e .shape [- 1 ]
362-
363- return distance
364-
365349 def __call__ (self , target , input ):
350+ predicted_vectors = input ["predicted_vectors" ]
366351 cone_axes = input ["cone_axes" ]
367352 cone_arcs = input ["cone_arcs" ]
368-
369- negated_cone_axes , negated_cone_arcs = self .negate (cone_arcs , cone_axes )
370-
371- predicted_vectors = input ["predicted_vectors" ]
372- loss = torch .zeros ((predicted_vectors .shape [0 ], cone_axes .shape [1 ]), device = target .get_device ())
373- fltr = target .bool ()
374- loss [fltr ] = 1 - self .cal_logit_cone (predicted_vectors , cone_axes , cone_arcs )[fltr ]
375- loss [~ fltr ] = 1 - self .cal_logit_cone (predicted_vectors , negated_cone_axes ,
376- negated_cone_arcs )[~ fltr ]
377- return torch .mean (loss )
353+ memberships = 1 - in_cone_parts (predicted_vectors , cone_axes , cone_arcs )
354+ loss = torch .nn .functional .binary_cross_entropy (memberships , target )
355+ return loss
0 commit comments