From cfbf02c9bd4dc476805dc1ddea666ac2c902fb17 Mon Sep 17 00:00:00 2001
From: Tianyang Zhao
Date: Sat, 29 Jul 2023 15:44:16 -0700
Subject: [PATCH] Update med.py: fix BertEncoder.forward() not returning
 cross-attentions when requested

In BertEncoder.forward(), `all_cross_attentions` is initialized at line 409 but
never updated afterwards, so None is returned when cross-attentions are
requested. This revision updates `all_cross_attentions` inside the layer loop
(line 461). The added code follows the original Hugging Face Transformers
implementation,
https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert/modeling_bert.py
(line 600), and has been tested.
---
 models/med.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/models/med.py b/models/med.py
index 7b00a354..fc83f6d2 100644
--- a/models/med.py
+++ b/models/med.py
@@ -458,6 +458,8 @@ def custom_forward(*inputs):
                 next_decoder_cache += (layer_outputs[-1],)
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
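
Note (not part of the patch): a minimal verification sketch of the fix. It
assumes the BLIP repository layout (models/med.py, configs/med_config.json with
add_cross_attention and encoder_width set) and the forward() signature of
models.med.BertModel (encoder_hidden_states, output_attentions,
mode='multimodal'); the paths and shapes below are illustrative only.

import torch
from models.med import BertConfig, BertModel

# Assumption: configs/med_config.json enables add_cross_attention and defines
# encoder_width, as in the BLIP repo; both are needed for cross-attention layers.
config = BertConfig.from_json_file('configs/med_config.json')
model = BertModel(config=config, add_pooling_layer=False)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))             # dummy text tokens
image_embeds = torch.randn(1, 10, config.encoder_width)             # dummy visual features
image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long)  # visual attention mask

with torch.no_grad():
    outputs = model(input_ids,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    output_attentions=True,
                    return_dict=True,
                    mode='multimodal')

# Before the fix, outputs.cross_attentions carried no per-layer maps even though
# cross-attention was computed; after the fix it holds one tensor per layer.
assert outputs.cross_attentions and outputs.cross_attentions[0] is not None
print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)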