from tempfile import TemporaryDirectory
import logging
import random
from math import pi
from torch import nn
from torch.nn.utils.rnn import (
    pack_padded_sequence,
@@ -233,3 +233,145 @@ def forward(self, data):
        electra = self.electra(data)
        d = torch.sum(electra.last_hidden_state, dim=1)
        return dict(logits=self.output(d), attentions=electra.attentions)

class ConeElectra(JCIBaseNet):
    NAME = "ConeElectra"

    def _get_data_and_labels(self, batch, batch_idx):
        # Attention mask covers each input sequence plus the prepended CLS token.
        mask = pad_sequence(
            [torch.ones(l + 1, device=self.device) for l in batch.lens],
            batch_first=True,
        )
        cls_tokens = (
            torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(-1)
            * CLS_TOKEN
        )
        return dict(
            features=torch.cat((cls_tokens, batch.x), dim=1),
            labels=batch.y,
            model_kwargs=dict(attention_mask=mask),
            target_mask=batch.target_mask,
        )

    @property
    def as_pretrained(self):
        return self.electra

    def __init__(self, cone_dimensions=20, **kwargs):
        # Pop this argument so that it is not stored as a hyperparameter.
        pretrained_checkpoint = (
            kwargs.pop("pretrained_checkpoint")
            if "pretrained_checkpoint" in kwargs
            else None
        )

        self.cone_dimensions = cone_dimensions

        super().__init__(**kwargs)
        if "num_labels" not in kwargs["config"] and self.out_dim is not None:
            kwargs["config"]["num_labels"] = self.out_dim
        self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
        self.word_dropout = nn.Dropout(kwargs["config"].get("word_dropout", 0))
        model_prefix = kwargs.get("load_prefix", None)
        if pretrained_checkpoint:
            with open(pretrained_checkpoint, "rb") as fin:
                model_dict = torch.load(fin, map_location=self.device)
                if model_prefix:
                    # Drop the prefix so the keys match ElectraModel's state dict.
                    state_dict = {
                        str(k)[len(model_prefix):]: v
                        for k, v in model_dict["state_dict"].items()
                        if str(k).startswith(model_prefix)
                    }
                else:
                    state_dict = model_dict["state_dict"]
                self.electra = ElectraModel.from_pretrained(
                    None, state_dict=state_dict, config=self.config
                )
        else:
            self.electra = ElectraModel(config=self.config)

        in_d = self.config.hidden_size

        self.line_embedding = nn.Sequential(
            nn.Dropout(self.config.hidden_dropout_prob),
            nn.Linear(in_d, in_d),
            nn.GELU(),
            nn.Dropout(self.config.hidden_dropout_prob),
            nn.Linear(in_d, self.cone_dimensions),
        )

        # One cone per label: axis angles initialised in [0, 2*pi),
        # arc angles in (-pi, pi].
        self.cone_axes = nn.Parameter(
            2 * pi * torch.rand((1, self.config.num_labels, self.cone_dimensions))
        )
        self.cone_arcs = nn.Parameter(
            pi * (1 - 2 * torch.rand((1, self.config.num_labels, self.cone_dimensions)))
        )

    def _get_data_for_loss(self, model_output, labels):
        mask = model_output.get("target_mask")
        d = model_output["predicted_vectors"]
        return dict(
            input=dict(
                predicted_vectors=d,
                cone_axes=self.cone_axes,
                cone_arcs=self.cone_arcs,
            ),
            target=labels.float(),
        )

    def _get_prediction_and_labels(self, data, labels, model_output):
        mask = model_output.get("target_mask")
        d = model_output["predicted_vectors"]

        # Prediction score per label: 1 - cone distance, higher means closer to the cone.
        d = 1 - ConeLoss.cal_logit_cone(d, self.cone_axes, self.cone_arcs)

        return d, labels.int()

    def forward(self, data, **kwargs):
        self.batch_size = data["features"].shape[0]
        inp = self.electra.embeddings.forward(data["features"])
        inp = self.word_dropout(inp)
        electra = self.electra(inputs_embeds=inp, **kwargs)
        # Use the hidden state of the CLS token as the sequence representation.
        d = electra.last_hidden_state[:, 0, :]
        return dict(
            predicted_vectors=self.line_embedding(d),
            attentions=electra.attentions,
            target_mask=data.get("target_mask"),
        )

class ConeLoss:

    def __init__(self, center_scaling=0.1):
        self.center_scaling = center_scaling

    def negate(self, ax, arc):
        # Negated cone: the axis is shifted by pi (in the direction that keeps it
        # in range) and the arc is complemented (pi - arc).
        offset = pi * torch.ones_like(ax)
        offset[ax >= 0] *= -1
        return ax + offset, pi - arc

    @classmethod
    def cal_logit_cone(cls, entity_embedding, query_axis_embedding, query_arg_embedding, center_scaling=0.2):
        """Cone distance adapted from https://github.com/MIRALab-USTC/QE-ConE

        :param entity_embedding: angular entity vectors, shape (batch, cone_dims)
        :param query_axis_embedding: cone axis angles, shape (1, num_labels, cone_dims)
        :param query_arg_embedding: cone arc angles, shape (1, num_labels, cone_dims)
        :param center_scaling: weight of the inside-cone distance term
        :return: distance of each entity to each cone, shape (batch, num_labels)
        """
        e = entity_embedding.unsqueeze(1)

        # Angular distance to the cone axis and to the cone boundary.
        distance2axis = torch.abs(torch.sin((e - query_axis_embedding) / 2))
        distance_base = torch.abs(torch.sin(query_arg_embedding / 2))

        # Entities whose distance to the axis is below the arc lie inside the cone.
        indicator_in = distance2axis < distance_base
        distance_out = torch.min(
            torch.abs(torch.sin(e - (query_axis_embedding - query_arg_embedding) / 2)),
            torch.abs(torch.sin(e - (query_axis_embedding + query_arg_embedding) / 2)),
        )
        distance_out[indicator_in] = 0.0

        distance_in = torch.min(distance2axis, distance_base)

        distance = (
            torch.norm(distance_out, p=1, dim=-1) / e.shape[-1]
            + center_scaling * torch.norm(distance_in, p=1, dim=-1) / e.shape[-1]
        )

        return distance

    def __call__(self, target, input):
        cone_axes = input["cone_axes"]
        cone_arcs = input["cone_arcs"]

        negated_cone_axes, negated_cone_arcs = self.negate(cone_axes, cone_arcs)

        predicted_vectors = input["predicted_vectors"]
        loss = torch.zeros(
            (predicted_vectors.shape[0], cone_axes.shape[1]),
            device=predicted_vectors.device,
        )
        fltr = target.bool()
        # Per-label losses: positive labels use the label's cone, negative labels its negation.
        loss[fltr] = 1 - self.cal_logit_cone(
            predicted_vectors, cone_axes, cone_arcs, center_scaling=self.center_scaling
        )[fltr]
        loss[~fltr] = 1 - self.cal_logit_cone(
            predicted_vectors, negated_cone_axes, negated_cone_arcs, center_scaling=self.center_scaling
        )[~fltr]
        return torch.mean(loss)
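

# A minimal usage sketch of how ConeLoss consumes the output of ConeElectra.forward:
# per-label cone parameters plus the predicted angular vectors. The tensor shapes and
# random inputs below are illustrative assumptions, not values from the training pipeline.
if __name__ == "__main__":
    import torch

    batch_size, num_labels, cone_dims = 4, 10, 20

    # Stand-ins for ConeElectra's learned parameters and its "predicted_vectors" output.
    cone_axes = 2 * pi * torch.rand((1, num_labels, cone_dims))
    cone_arcs = pi * (1 - 2 * torch.rand((1, num_labels, cone_dims)))
    predicted_vectors = torch.rand((batch_size, cone_dims))
    labels = torch.randint(0, 2, (batch_size, num_labels)).float()

    criterion = ConeLoss(center_scaling=0.1)
    loss = criterion(
        target=labels,
        input=dict(
            predicted_vectors=predicted_vectors,
            cone_axes=cone_axes,
            cone_arcs=cone_arcs,
        ),
    )

    # Prediction scores as in _get_prediction_and_labels: 1 - cone distance.
    scores = 1 - ConeLoss.cal_logit_cone(predicted_vectors, cone_axes, cone_arcs)
    print(loss.item(), scores.shape)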