Commit 7cdfa64

Implement custom membership function
1 parent 84f8b7e commit 7cdfa64

File tree: 1 file changed (+12, −34 lines)


chebai/models/electra.py

Lines changed: 12 additions & 34 deletions
@@ -328,6 +328,14 @@ def forward(self, data, **kwargs):
             target_mask=data.get("target_mask"),
         )
 
+def softabs(x, eps=0.01):
+    return (x**2 + eps)**0.5 - eps**0.5
+
+def in_cone_parts(vectors, cone_axes, cone_arcs):
+    theta_L = cone_axes + cone_arcs
+    theta_R = cone_axes - cone_arcs
+    return ((softabs(vectors - theta_L) + softabs(vectors - theta_R)) - cone_arcs) / (2*pi - cone_arcs)
+
 class ConeLoss:
 
     def __init__(self, center_scaling=0.1):
@@ -338,40 +346,10 @@ def negate(self, ax, arc):
         offset[ax >= 0] *= -1
         return ax + offset, pi - arc
 
-    @classmethod
-    def cal_logit_cone(cls, entity_embedding, query_axis_embedding, query_arg_embedding, center_scaling=0.2):
-        """Cone distance from https://github.com/MIRALab-USTC/QE-ConE
-        :param entity_embedding:
-        :param query_axis_embedding:
-        :param query_arg_embedding:
-        :return:
-        """
-
-        e = entity_embedding.unsqueeze(1)
-
-        distance2axis = torch.abs(torch.sin((e - query_axis_embedding) / 2))
-        distance_base = torch.abs(torch.sin(query_arg_embedding / 2))
-
-        indicator_in = distance2axis < distance_base
-        distance_out = torch.min(torch.abs(torch.sin(e - (query_axis_embedding - query_arg_embedding) / 2)), torch.abs(torch.sin(e - (query_axis_embedding + query_arg_embedding) / 2)))
-        distance_out[indicator_in] = 0.
-
-        distance_in = torch.min(distance2axis, distance_base)
-
-        distance = torch.norm(distance_out, p=1, dim=-1)/e.shape[-1] + center_scaling * torch.norm(distance_in, p=1, dim=-1)/e.shape[-1]
-
-        return distance
-
     def __call__(self, target, input):
+        predicted_vectors = input["predicted_vectors"]
         cone_axes = input["cone_axes"]
         cone_arcs = input["cone_arcs"]
-
-        negated_cone_axes, negated_cone_arcs = self.negate(cone_arcs, cone_axes)
-
-        predicted_vectors = input["predicted_vectors"]
-        loss = torch.zeros((predicted_vectors.shape[0], cone_axes.shape[1]), device=target.get_device())
-        fltr = target.bool()
-        loss[fltr] = 1 - self.cal_logit_cone(predicted_vectors, cone_axes, cone_arcs)[fltr]
-        loss[~fltr] = 1 - self.cal_logit_cone(predicted_vectors, negated_cone_axes,
-                                              negated_cone_arcs)[~fltr]
-        return torch.mean(loss)
+        memberships = 1 - in_cone_parts(predicted_vectors, cone_axes, cone_arcs)
+        loss = torch.nn.functional.binary_cross_entropy(memberships, target)
+        return loss
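
For context, a minimal standalone sketch of how the new membership-based loss might be exercised. The surrounding module already uses torch and pi, so those imports match the file; the tensor shapes and the clamp into [0, 1] below are illustrative assumptions, not part of the commit:

import torch
from math import pi

def softabs(x, eps=0.01):
    # Smooth, differentiable approximation of |x| with softabs(0) == 0.
    return (x**2 + eps)**0.5 - eps**0.5

def in_cone_parts(vectors, cone_axes, cone_arcs):
    # Soft distance of each angle in `vectors` to the cone spanning
    # [cone_axes - cone_arcs, cone_axes + cone_arcs]: roughly constant
    # inside the cone and growing with the distance outside it.
    theta_L = cone_axes + cone_arcs
    theta_R = cone_axes - cone_arcs
    return ((softabs(vectors - theta_L) + softabs(vectors - theta_R)) - cone_arcs) / (2*pi - cone_arcs)

# Hypothetical shapes: a batch of 4 embeddings with 3 cone-valued classes.
predicted_vectors = torch.rand(4, 3) * 2 * pi
cone_axes = torch.rand(4, 3) * 2 * pi
cone_arcs = torch.rand(4, 3)
target = torch.randint(0, 2, (4, 3)).float()

memberships = 1 - in_cone_parts(predicted_vectors, cone_axes, cone_arcs)
# binary_cross_entropy requires inputs in [0, 1]; the soft distance is not
# bounded by construction, so this sketch clamps (an assumption, not in the diff).
loss = torch.nn.functional.binary_cross_entropy(memberships.clamp(0.0, 1.0), target)
print(loss.item())

Compared with the removed cal_logit_cone path, which scored positive and negative targets separately via self.negate and a boolean mask, the new formulation treats 1 - in_cone_parts as a per-class membership probability and lets binary_cross_entropy handle both polarities in a single call.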
