
Commit 27f362e

Implement cone-based electra

1 parent: bde515f

2 files changed: 151 additions & 1 deletion

chebai/experiments.py

Lines changed: 8 additions & 0 deletions

@@ -283,6 +283,14 @@ def model_kwargs(self, *args) -> Dict:
         return d


+class ElectraConeOnTox21MoleculeNet(ElectraOnChEBI100):
+    MODEL = electra.ConeElectra
+    LOSS = electra.ConeLoss
+
+    @classmethod
+    def identifier(cls) -> str:
+        return "ElectraCone+Chebi100"
+
 class ElectraOnTox21Challenge(_ElectraExperiment):
     @classmethod
     def identifier(cls) -> str:
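
The new experiment follows this file's existing convention: subclass an existing experiment and override MODEL, LOSS, and identifier(). Purely to illustrate that pattern, a hypothetical sibling experiment (not part of this commit) pairing the cone model with the ElectraOnTox21Challenge class visible in the surrounding context could read:

# Hypothetical example only; not part of this commit. It mirrors the pattern of
# ElectraConeOnTox21MoleculeNet above and reuses classes already present in this file.
class ElectraConeOnTox21Challenge(ElectraOnTox21Challenge):
    MODEL = electra.ConeElectra
    LOSS = electra.ConeLoss

    @classmethod
    def identifier(cls) -> str:
        return "ElectraCone+Tox21"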

chebai/models/electra.py

Lines changed: 143 additions & 1 deletion

@@ -2,7 +2,7 @@
 from tempfile import TemporaryDirectory
 import logging
 import random
-
+from math import pi
 from torch import nn
 from torch.nn.utils.rnn import (
     pack_padded_sequence,
@@ -233,3 +233,145 @@ def forward(self, data):
         electra = self.electra(data)
         d = torch.sum(electra.last_hidden_state, dim=1)
         return dict(logits=self.output(d), attentions=electra.attentions)
+
+class ConeElectra(JCIBaseNet):
+    NAME = "ConeElectra"
+
+    def _get_data_and_labels(self, batch, batch_idx):
+        mask = pad_sequence(
+            [torch.ones(l + 1, device=self.device) for l in batch.lens],
+            batch_first=True,
+        )
+        cls_tokens = (
+            torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
+                -1
+            )
+            * CLS_TOKEN
+        )
+        return dict(
+            features=torch.cat((cls_tokens, batch.x), dim=1),
+            labels=batch.y,
+            model_kwargs=dict(attention_mask=mask),
+            target_mask=batch.target_mask,
+        )
+
+    @property
+    def as_pretrained(self):
+        return self.electra.electra
+
+    def __init__(self, cone_dimensions=20, **kwargs):
+        # Pop the checkpoint path so that it is not stored as a
+        # hyperparameter
+        pretrained_checkpoint = (
+            kwargs.pop("pretrained_checkpoint")
+            if "pretrained_checkpoint" in kwargs
+            else None
+        )
+
+        self.cone_dimensions = cone_dimensions
+
+        super().__init__(**kwargs)
+        if "num_labels" not in kwargs["config"] and self.out_dim is not None:
+            kwargs["config"]["num_labels"] = self.out_dim
+        self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
+        self.word_dropout = nn.Dropout(kwargs["config"].get("word_dropout", 0))
+        model_prefix = kwargs.get("load_prefix", None)
+        if pretrained_checkpoint:
+            with open(pretrained_checkpoint, "rb") as fin:
+                model_dict = torch.load(fin, map_location=self.device)
+                if model_prefix:
+                    state_dict = {str(k)[len(model_prefix):]: v for k, v in model_dict["state_dict"].items() if str(k).startswith(model_prefix)}
+                else:
+                    state_dict = model_dict["state_dict"]
+                self.electra = ElectraModel.from_pretrained(None, state_dict=state_dict, config=self.config)
+        else:
+            self.electra = ElectraModel(config=self.config)
+
+        in_d = self.config.hidden_size
+
+        self.line_embedding = nn.Sequential(
+            nn.Dropout(self.config.hidden_dropout_prob),
+            nn.Linear(in_d, in_d),
+            nn.GELU(),
+            nn.Dropout(self.config.hidden_dropout_prob),
+            nn.Linear(in_d, self.cone_dimensions),
+        )
+
+        self.cone_axes = nn.Parameter(2 * pi * torch.rand((1, self.config.num_labels, self.cone_dimensions)))
+        self.cone_arcs = nn.Parameter(pi * (1 - 2 * torch.rand((1, self.config.num_labels, self.cone_dimensions))))
+
+    def _get_data_for_loss(self, model_output, labels):
+        mask = model_output.get("target_mask")
+        d = model_output["predicted_vectors"]
+        return dict(input=dict(predicted_vectors=d,
+                               cone_axes=self.cone_axes,
+                               cone_arcs=self.cone_arcs),
+                    target=labels.float())
+
+    def _get_prediction_and_labels(self, data, labels, model_output):
+        mask = model_output.get("target_mask")
+        d = model_output["predicted_vectors"]
+
+        d = 1 - ConeLoss.cal_logit_cone(d, self.cone_axes, self.cone_arcs)
+
+        return d, labels.int()
+
+    def forward(self, data, **kwargs):
+        self.batch_size = data["features"].shape[0]
+        inp = self.electra.embeddings.forward(data["features"])
+        inp = self.word_dropout(inp)
+        electra = self.electra(inputs_embeds=inp, **kwargs)
+        d = electra.last_hidden_state[:, 0, :]
+        return dict(
+            predicted_vectors=self.line_embedding(d),
+            attentions=electra.attentions,
+            target_mask=data.get("target_mask"),
+        )
+
+class ConeLoss:
+
+    def __init__(self, center_scaling=0.1):
+        self.center_scaling = center_scaling
+
+    def negate(self, ax, arc):
+        # Complement cone: rotate the axis by pi (toward the origin of the range)
+        # and take the complementary aperture.
+        offset = pi * torch.ones_like(ax)
+        offset[ax >= 0] *= -1
+        return ax + offset, pi - arc
+
+    @classmethod
+    def cal_logit_cone(cls, entity_embedding, query_axis_embedding, query_arg_embedding, center_scaling=0.2):
+        """Cone distance from https://github.com/MIRALab-USTC/QE-ConE
+        :param entity_embedding: predicted angular embeddings, shape (batch, cone_dimensions)
+        :param query_axis_embedding: cone axis angles, shape (1, num_labels, cone_dimensions)
+        :param query_arg_embedding: cone aperture (arc) angles, same shape as the axes
+        :param center_scaling: weight of the inside-of-cone distance term
+        :return: per-label cone distance, shape (batch, num_labels)
+        """

+        e = entity_embedding.unsqueeze(1)
+
+        distance2axis = torch.abs(torch.sin((e - query_axis_embedding) / 2))
+        distance_base = torch.abs(torch.sin(query_arg_embedding / 2))
+
+        indicator_in = distance2axis < distance_base
+        distance_out = torch.min(torch.abs(torch.sin(e - (query_axis_embedding - query_arg_embedding) / 2)), torch.abs(torch.sin(e - (query_axis_embedding + query_arg_embedding) / 2)))
+        distance_out[indicator_in] = 0.0
+
+        distance_in = torch.min(distance2axis, distance_base)
+
+        distance = torch.norm(distance_out, p=1, dim=-1) / e.shape[-1] + center_scaling * torch.norm(distance_in, p=1, dim=-1) / e.shape[-1]
+
+        return distance
+
+    def __call__(self, target, input):
+        cone_axes = input["cone_axes"]
+        cone_arcs = input["cone_arcs"]
+
+        negated_cone_axes, negated_cone_arcs = self.negate(cone_axes, cone_arcs)
+
+        predicted_vectors = input["predicted_vectors"]
+        loss = torch.zeros((predicted_vectors.shape[0], cone_axes.shape[1]), device=predicted_vectors.device)
+        fltr = target.bool()
+        loss[fltr] = 1 - self.cal_logit_cone(predicted_vectors, cone_axes, cone_arcs)[fltr]
+        loss[~fltr] = 1 - self.cal_logit_cone(predicted_vectors, negated_cone_axes,
+                                              negated_cone_arcs)[~fltr]
+        return torch.mean(loss)
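
The two classes added above work together: ConeElectra projects the ELECTRA [CLS] embedding to cone_dimensions angles per molecule and learns one cone (axis plus aperture) per label, while ConeLoss pulls embeddings of positive labels into the label's cone and embeddings of negative labels into the complement cone produced by negate(). A minimal shape sanity check, based only on the code in this diff (tensor sizes and variable names below are illustrative, not part of the commit):

# Shape sanity check for the cone components above; a sketch, not part of the commit.
# The import path mirrors this diff, everything else is illustrative.
import torch
from math import pi

from chebai.models.electra import ConeLoss

batch, num_labels, cone_dims = 4, 10, 20
predicted = 2 * pi * torch.rand(batch, cone_dims)           # stands in for ConeElectra's line_embedding output
axes = 2 * pi * torch.rand(1, num_labels, cone_dims)        # one learnable axis angle per label and dimension
arcs = pi * (1 - 2 * torch.rand(1, num_labels, cone_dims))  # learnable aperture per label and dimension

# Distance of every molecule embedding to every label cone: shape (batch, num_labels).
dist = ConeLoss.cal_logit_cone(predicted, axes, arcs)
assert dist.shape == (batch, num_labels)

# The loss receives the dict produced by ConeElectra._get_data_for_loss.
target = torch.randint(0, 2, (batch, num_labels)).float()
loss = ConeLoss()(target, dict(predicted_vectors=predicted, cone_axes=axes, cone_arcs=arcs))
print(dist.shape, float(loss))

Note that _get_prediction_and_labels reports 1 minus this distance as the per-label score, so values close to 1 indicate the embedding lies (nearly) inside the label's cone.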
