@@ -310,11 +310,11 @@ def _get_data_for_loss(self, model_output, labels):
310310
311311 def _get_prediction_and_labels (self , data , labels , model_output ):
312312 mask = model_output .get ("target_mask" )
313- d = model_output ["predicted_vectors" ]
313+ d = model_output ["predicted_vectors" ]. unsqueeze ( 1 )
314314
315- d = 1 - ConeLoss . cal_logit_cone (d , self .cone_axes , self .cone_arcs )
315+ d = in_cone_parts (d , self .cone_axes , self .cone_arcs )
316316
317- return d , labels .int ()
317+ return torch . mean ( d , dim = - 1 ) , labels .int ()
318318
319319 def forward (self , data , ** kwargs ):
320320 self .batch_size = data ["features" ].shape [0 ]
@@ -331,10 +331,35 @@ def forward(self, data, **kwargs):
331331def softabs (x , eps = 0.01 ):
332332 return (x ** 2 + eps )** 0.5 - eps ** 0.5
333333
334+ def anglify (x ):
335+ return torch .tanh (x )* pi
336+
337+ def turn (vector , angle ):
338+ v = vector - angle
339+ return v - (v > pi )* 2 * pi + (v < - pi )* 2 * pi
340+
334341def in_cone_parts (vectors , cone_axes , cone_arcs ):
335- theta_L = cone_axes + cone_arcs
336- theta_R = cone_axes - cone_arcs
337- return ((softabs (vectors - theta_L ) + softabs (vectors - theta_R )) - cone_arcs )/ (2 * pi - cone_arcs )
342+
343+ """
344+ # trap between -pi and pi
345+ cone_ax_ang = anglify(cone_axes)
346+ v = anglify(vectors)
347+
348+ # trap between 0 and pi
349+ cone_arc_ang = (torch.tanh(cone_arcs)+1)*pi/2
350+ theta_L = cone_ax_ang - cone_arc_ang/2
351+ #theta_L = theta_L - (theta_L > 2*pi) * 2 * pi + (theta_L < 0) *2*pi
352+ theta_R = cone_ax_ang + cone_arc_ang/2
353+ #theta_R = theta_R - (theta_R > 2 * pi) * 2 * pi + (theta_R < 0) * 2 * pi
354+ dis = (torch.abs(turn(v, theta_L)) + torch.abs(turn(v, theta_R)) - cone_arc_ang)/(2*pi-cone_arc_ang)
355+ return dis
356+ """
357+ a = cone_axes - cone_arcs ** 2
358+ b = cone_axes + cone_arcs ** 2
359+ bigger_than_a = torch .sigmoid (vectors - a )
360+ smaller_than_b = torch .sigmoid (b - vectors )
361+ return bigger_than_a * smaller_than_b
362+
338363
339364class ConeLoss :
340365
@@ -347,9 +372,9 @@ def negate(self, ax, arc):
347372 return ax + offset , pi - arc
348373
349374 def __call__ (self , target , input ):
350- predicted_vectors = input ["predicted_vectors" ]
375+ predicted_vectors = input ["predicted_vectors" ]. unsqueeze ( 1 )
351376 cone_axes = input ["cone_axes" ]
352377 cone_arcs = input ["cone_arcs" ]
353- memberships = 1 - in_cone_parts (predicted_vectors , cone_axes , cone_arcs )
354- loss = torch .nn .functional .binary_cross_entropy (memberships , target )
378+ memberships = ( 1 - 1e-6 ) * ( in_cone_parts (predicted_vectors , cone_axes , cone_arcs ) )
379+ loss = torch .nn .functional .binary_cross_entropy (memberships , target . unsqueeze ( - 1 ). expand ( - 1 , - 1 , 20 ) )
355380 return loss
0 commit comments