    "DocPoolerBatchInput",
    {
        "embedding": BatchInput,
-        "mask": torch.Tensor,  # shape: (batch_size, seq_len)
+        "mask": torch.Tensor,
        "stats": Dict[str, Any],
    },
)

DocPoolerBatchOutput = TypedDict(
    "DocPoolerBatchOutput",
    {
-        "embeddings": torch.Tensor,  # shape: (batch_size, embedding_dim)
+        "embeddings": torch.Tensor,
    },
)

@@ -51,13 +51,17 @@ def __init__(
        name: str = "document_pooler",
        *,
        embedding: WordEmbeddingComponent,
-        pooling_mode: Literal["max", "sum", "mean", "cls"] = "mean",
+        pooling_mode: Literal["max", "sum", "mean", "cls", "attention"] = "mean",
    ):
        super().__init__(nlp, name)
        self.embedding = embedding
        self.pooling_mode = pooling_mode
        self.output_size = embedding.output_size

+        # Add attention layer if needed
+        if pooling_mode == "attention":
+            self.attention = torch.nn.Linear(self.output_size, 1)
+
    def preprocess(self, doc: Doc, **kwargs) -> Dict[str, Any]:
        embedding_out = self.embedding.preprocess(doc, **kwargs)
        return {
@@ -76,26 +80,47 @@ def collate(self, batch: Dict[str, Any]) -> DocPoolerBatchInput:
        }

    def forward(self, batch: DocPoolerBatchInput) -> DocPoolerBatchOutput:
-        embeds = self.embedding(batch["embedding"])["embeddings"]
+        """
+        Forward pass: compute document embeddings using the selected pooling strategy
+        """
+        embeds = self.embedding(batch["embedding"])["embeddings"].refold(
+            "context", "word"
+        )
        device = embeds.device

        if self.pooling_mode == "cls":
            pooled = self.embedding(batch["embedding"])["cls"].to(device)
            return {"embeddings": pooled}

        mask = embeds.mask
-        mask_expanded = mask.unsqueeze(-1)
-        masked_embeds = embeds * mask_expanded
-        sum_embeds = masked_embeds.sum(dim=1)
-        if self.pooling_mode == "mean":
-            valid_counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
-            pooled = sum_embeds / valid_counts
-        elif self.pooling_mode == "max":
-            masked_embeds = embeds.masked_fill(~mask_expanded, float("-inf"))
-            pooled, _ = masked_embeds.max(dim=1)
-        elif self.pooling_mode == "sum":
-            pooled = sum_embeds
+
+        if self.pooling_mode == "attention":
+            attention_weights = self.attention(embeds)  # (batch_size, seq_len, 1)
+            attention_weights = attention_weights.squeeze(-1)  # (batch_size, seq_len)
+
+            attention_weights = attention_weights.masked_fill(~mask, float("-inf"))
+
+            attention_weights = torch.softmax(attention_weights, dim=1)
+
+            attention_weights = attention_weights.unsqueeze(
+                -1
+            )  # (batch_size, seq_len, 1)
+            pooled = (embeds * attention_weights).sum(dim=1)  # (batch_size, embed_dim)
+
        else:
-            raise ValueError(f"Unknown pooling mode: {self.pooling_mode}")
+            mask_expanded = mask.unsqueeze(-1)
+            masked_embeds = embeds * mask_expanded
+            sum_embeds = masked_embeds.sum(dim=1)
+
+            if self.pooling_mode == "mean":
+                valid_counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
+                pooled = sum_embeds / valid_counts
+            elif self.pooling_mode == "max":
+                masked_embeds = embeds.masked_fill(~mask_expanded, float("-inf"))
+                pooled, _ = masked_embeds.max(dim=1)
+            elif self.pooling_mode == "sum":
+                pooled = sum_embeds
+            else:
+                raise ValueError(f"Unknown pooling mode: {self.pooling_mode}")

        return {"embeddings": pooled}