Commit 8c6d78d

add attention_pooling
1 parent 7ca9544 commit 8c6d78d

File tree

1 file changed: +41 −16 lines
edsnlp/pipes/trainable/embeddings/doc_pooler/doc_pooler.py

Lines changed: 41 additions & 16 deletions
@@ -13,15 +13,15 @@
     "DocPoolerBatchInput",
     {
         "embedding": BatchInput,
-        "mask": torch.Tensor,  # shape: (batch_size, seq_len)
+        "mask": torch.Tensor,
         "stats": Dict[str, Any],
     },
 )
 
 DocPoolerBatchOutput = TypedDict(
     "DocPoolerBatchOutput",
     {
-        "embeddings": torch.Tensor,  # shape: (batch_size, embedding_dim)
+        "embeddings": torch.Tensor,
     },
 )
 
@@ -51,13 +51,17 @@ def __init__(
         name: str = "document_pooler",
         *,
         embedding: WordEmbeddingComponent,
-        pooling_mode: Literal["max", "sum", "mean", "cls"] = "mean",
+        pooling_mode: Literal["max", "sum", "mean", "cls", "attention"] = "mean",
     ):
         super().__init__(nlp, name)
         self.embedding = embedding
         self.pooling_mode = pooling_mode
         self.output_size = embedding.output_size
 
+        # Add attention layer if needed
+        if pooling_mode == "attention":
+            self.attention = torch.nn.Linear(self.output_size, 1)
+
     def preprocess(self, doc: Doc, **kwargs) -> Dict[str, Any]:
         embedding_out = self.embedding.preprocess(doc, **kwargs)
         return {
@@ -76,26 +80,47 @@ def collate(self, batch: Dict[str, Any]) -> DocPoolerBatchInput:
         }
 
     def forward(self, batch: DocPoolerBatchInput) -> DocPoolerBatchOutput:
-        embeds = self.embedding(batch["embedding"])["embeddings"]
+        """
+        Forward pass: compute document embeddings using the selected pooling strategy
+        """
+        embeds = self.embedding(batch["embedding"])["embeddings"].refold(
+            "context", "word"
+        )
         device = embeds.device
 
         if self.pooling_mode == "cls":
             pooled = self.embedding(batch["embedding"])["cls"].to(device)
             return {"embeddings": pooled}
 
         mask = embeds.mask
-        mask_expanded = mask.unsqueeze(-1)
-        masked_embeds = embeds * mask_expanded
-        sum_embeds = masked_embeds.sum(dim=1)
-        if self.pooling_mode == "mean":
-            valid_counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
-            pooled = sum_embeds / valid_counts
-        elif self.pooling_mode == "max":
-            masked_embeds = embeds.masked_fill(~mask_expanded, float("-inf"))
-            pooled, _ = masked_embeds.max(dim=1)
-        elif self.pooling_mode == "sum":
-            pooled = sum_embeds
+
+        if self.pooling_mode == "attention":
+            attention_weights = self.attention(embeds)  # (batch_size, seq_len, 1)
+            attention_weights = attention_weights.squeeze(-1)  # (batch_size, seq_len)
+
+            attention_weights = attention_weights.masked_fill(~mask, float("-inf"))
+
+            attention_weights = torch.softmax(attention_weights, dim=1)
+
+            attention_weights = attention_weights.unsqueeze(
+                -1
+            )  # (batch_size, seq_len, 1)
+            pooled = (embeds * attention_weights).sum(dim=1)  # (batch_size, embed_dim)
+
         else:
-            raise ValueError(f"Unknown pooling mode: {self.pooling_mode}")
+            mask_expanded = mask.unsqueeze(-1)
+            masked_embeds = embeds * mask_expanded
+            sum_embeds = masked_embeds.sum(dim=1)
+
+            if self.pooling_mode == "mean":
+                valid_counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
+                pooled = sum_embeds / valid_counts
+            elif self.pooling_mode == "max":
+                masked_embeds = embeds.masked_fill(~mask_expanded, float("-inf"))
+                pooled, _ = masked_embeds.max(dim=1)
+            elif self.pooling_mode == "sum":
+                pooled = sum_embeds
+            else:
+                raise ValueError(f"Unknown pooling mode: {self.pooling_mode}")
 
         return {"embeddings": pooled}

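Note on the new "attention" mode: each token embedding is scored with a learned linear layer, scores on padded positions are masked to -inf, a softmax turns the scores into weights that sum to 1 over the real tokens, and the document embedding is the weighted sum of the token embeddings. Below is a minimal, standalone sketch of that pooling step on plain torch tensors; tensor sizes and variable names are illustrative, and unlike the real component it does not use edsnlp's folded tensors or train the scoring layer jointly with the pipeline.

import torch

torch.manual_seed(0)

batch_size, seq_len, dim = 2, 5, 8
embeds = torch.randn(batch_size, seq_len, dim)
# True marks real tokens, False marks padding
mask = torch.tensor(
    [
        [True, True, True, False, False],
        [True, True, True, True, True],
    ]
)

# One scalar score per token, mirroring torch.nn.Linear(self.output_size, 1)
attention = torch.nn.Linear(dim, 1)

scores = attention(embeds).squeeze(-1)              # (batch_size, seq_len)
scores = scores.masked_fill(~mask, float("-inf"))   # padded positions get zero weight
weights = torch.softmax(scores, dim=1)              # each row sums to 1 over real tokens
pooled = (embeds * weights.unsqueeze(-1)).sum(dim=1)  # (batch_size, dim)

print(pooled.shape)  # torch.Size([2, 8])

Masking with float("-inf") before the softmax is what keeps padding from diluting the document embedding: exp(-inf) is 0, so padded tokens receive exactly zero weight and the remaining weights renormalize over the real tokens only, mirroring the masked_fill line in the diff above.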