Skip to content

Commit d774567

Browse files
committed
Implement box loss
1 parent 7cdfa64 commit d774567

File tree

1 file changed

+34
-9
lines changed

1 file changed

+34
-9
lines changed

chebai/models/electra.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,11 @@ def _get_data_for_loss(self, model_output, labels):
310310

311311
def _get_prediction_and_labels(self, data, labels, model_output):
    """Turn the model output into per-class scores and integer labels.

    Reads ``model_output["predicted_vectors"]``, scores it against
    ``self.cone_axes`` / ``self.cone_arcs`` with ``in_cone_parts`` and
    averages the result over the last dimension.

    Args:
        data: Unused here; kept for interface compatibility.
        labels: Target labels; returned cast to ``int``.
        model_output: Mapping that contains ``"predicted_vectors"``.

    Returns:
        A ``(scores, labels)`` tuple, where ``scores`` is the mean of the
        memberships over the last dimension.
    """
    # Insert an axis at position 1 so the vectors can broadcast against the
    # stored axes/arcs.
    # NOTE(review): assumes predicted_vectors is (batch, dim) and the
    # axes/arcs are (classes, dim) -- confirm against the caller.
    vectors = model_output["predicted_vectors"].unsqueeze(1)
    memberships = in_cone_parts(vectors, self.cone_axes, self.cone_arcs)
    return torch.mean(memberships, dim=-1), labels.int()
318318

319319
def forward(self, data, **kwargs):
320320
self.batch_size = data["features"].shape[0]
@@ -331,10 +331,35 @@ def forward(self, data, **kwargs):
331331
def softabs(x, eps=0.01):
    """Smooth approximation of ``abs(x)``.

    Computes ``sqrt(x**2 + eps) - sqrt(eps)``. The subtraction makes the
    result exactly zero at ``x == 0``.

    Args:
        x: Value to take the soft absolute value of.
        eps: Smoothing constant added under the square root.

    Returns:
        The smoothed absolute value of ``x``.
    """
    return (x ** 2 + eps) ** 0.5 - eps ** 0.5
333333

334+
def anglify(x):
    """Map ``x`` to an angle between ``-pi`` and ``pi`` via ``tanh(x) * pi``.

    Args:
        x: Tensor of unbounded values.

    Returns:
        Tensor of the same shape with every entry scaled into the range
        spanned by ``-pi`` and ``pi``.
    """
    return torch.tanh(x) * pi
336+
337+
def turn(vector, angle):
    """Subtract ``angle`` from ``vector`` and wrap the result by one full turn.

    Values above ``pi`` are shifted down by ``2*pi`` and values below ``-pi``
    are shifted up by ``2*pi``. Only a single wrap is applied, so the result
    lies in ``[-pi, pi]`` only when ``|vector - angle| <= 3*pi``.

    Args:
        vector: Angle(s) to rotate.
        angle: Angle(s) to subtract.

    Returns:
        The wrapped difference ``vector - angle``.
    """
    shifted = vector - angle
    wrap_down = (shifted > pi) * 2 * pi
    wrap_up = (shifted < -pi) * 2 * pi
    return shifted - wrap_down + wrap_up
340+
334341
def in_cone_parts(vectors, cone_axes, cone_arcs):
    """Soft element-wise membership of ``vectors`` in an interval.

    The interval is ``[cone_axes - cone_arcs**2, cone_axes + cone_arcs**2]``.
    Squaring ``cone_arcs`` keeps the half-width non-negative whatever the
    sign of the raw parameter. Each bound is relaxed with a sigmoid and the
    two are multiplied, so the result approaches 1 well inside the interval
    and 0 far outside it.

    Args:
        vectors: Tensor of values to test.
        cone_axes: Tensor of interval centres, broadcastable with ``vectors``.
        cone_arcs: Tensor whose square is the interval half-width,
            broadcastable with ``cone_axes``.

    Returns:
        Tensor of memberships between 0 and 1 with the broadcast shape of
        the inputs.
    """
    half_width = cone_arcs ** 2
    lower = cone_axes - half_width
    upper = cone_axes + half_width
    above_lower = torch.sigmoid(vectors - lower)
    below_upper = torch.sigmoid(upper - vectors)
    return above_lower * below_upper
362+
338363

339364
class ConeLoss:
340365

@@ -347,9 +372,9 @@ def negate(self, ax, arc):
347372
return ax + offset, pi - arc
348373

349374
def __call__(self, target, input):
    """Binary cross-entropy between soft memberships and the targets.

    Args:
        target: Target tensor; a trailing axis is added and it is expanded
            to the shape of the memberships.
        input: Mapping with the keys ``"predicted_vectors"``,
            ``"cone_axes"`` and ``"cone_arcs"``.

    Returns:
        The loss computed by ``torch.nn.functional.binary_cross_entropy``.
    """
    # NOTE(review): assumes predicted_vectors is (batch, dim) and target is
    # (batch, classes) -- confirm against the caller.
    predicted_vectors = input["predicted_vectors"].unsqueeze(1)
    cone_axes = input["cone_axes"]
    cone_arcs = input["cone_arcs"]
    # Scale by a factor just under 1 so memberships stay strictly below 1
    # (presumably to keep log(1 - p) finite inside the BCE).
    memberships = (1 - 1e-6) * in_cone_parts(predicted_vectors, cone_axes, cone_arcs)
    # Match the target to whatever trailing dimension the memberships have
    # instead of assuming an embedding size of 20.
    expanded_target = target.unsqueeze(-1).expand_as(memberships)
    loss = torch.nn.functional.binary_cross_entropy(memberships, expanded_target)
    return loss

0 commit comments

Comments
 (0)