Commit 66fa7c4

Fixes EmbeddingScorer._prepare() passing an argument of the wrong type (#133)
- fix: corrects the type hints in the embedding and surprisal scorers; also slightly changes the types in ActivatingExample and propagates those changes
1 parent: db49cb7

4 files changed · +77 −80 lines
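Before the per-file diffs, here is a minimal reconstruction of the mismatch named in the commit title. These are illustrative stubs, not the real classes: `_prepare` was annotated as returning `list[list[Sample]]` while `_query` consumes a flat `list[Sample]`, so the old code only passed type checking through a chain of `# type: ignore` comments.

```python
# Illustrative stubs only; the real Sample and _query live in
# delphi/scorers/embedding/embedding.py.
class Sample: ...


def _query(samples: list[Sample]) -> None: ...


def _prepare_old() -> list[list[Sample]]:  # old, wrong annotation
    return [[Sample()]]


def _prepare_new() -> list[Sample]:  # annotation after this commit
    return [Sample()]


_query(_prepare_new())  # accepted by a type checker
# _query(_prepare_old())  # rejected: list[list[Sample]] is not list[Sample]
```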

delphi/latents/constructors.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -47,7 +47,6 @@ def prepare_non_activating_examples(
             NonActivatingExample(
                 tokens=toks,
                 activations=acts,
-                normalized_activations=None,
                 distance=distance,
                 str_tokens=tokenizer.batch_decode(toks),
             )
@@ -281,7 +280,6 @@ def constructor(
             ActivatingExample(
                 tokens=toks,
                 activations=acts,
-                normalized_activations=None,
             )
             for toks, acts in zip(token_windows, act_windows)
         ]
```
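With `normalized_activations` now declared (with a `None` default) on the subclasses rather than set explicitly at construction time, the call sites above simply drop the argument. A minimal sketch of the resulting construction pattern, assuming the `delphi` package is installed and using placeholder tensors in place of real token/activation windows:

```python
import torch

from delphi.latents import ActivatingExample

# Placeholder windows standing in for real token/activation data.
token_windows = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]
act_windows = [torch.rand(3), torch.rand(3)]

# normalized_activations is no longer passed; it defaults to None.
examples = [
    ActivatingExample(tokens=toks, activations=acts)
    for toks, acts in zip(token_windows, act_windows)
]
```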

delphi/latents/latents.py

Lines changed: 10 additions & 7 deletions

```diff
@@ -75,12 +75,6 @@ class Example:
     activations: Float[Tensor, "ctx_len"]
     """Activation values for the input sequence."""
 
-    str_tokens: list[str] | None = None
-    """Tokenized input sequence as strings."""
-
-    normalized_activations: Optional[Float[Tensor, "ctx_len"]] = None
-    """Activations quantized to integers in [0, 10]."""
-
     @property
     def max_activation(self) -> float:
         """
@@ -98,6 +92,12 @@ class ActivatingExample(Example):
     An example of a latent that activates a model.
     """
 
+    normalized_activations: Optional[Float[Tensor, "ctx_len"]] = None
+    """Activations quantized to integers in [0, 10]."""
+
+    str_tokens: Optional[list[str]] = None
+    """Tokenized input sequence as strings."""
+
     quantile: int = 0
     """The quantile of the activating example."""
 
@@ -108,6 +108,9 @@ class NonActivatingExample(Example):
     An example of a latent that does not activate a model.
     """
 
+    str_tokens: list[str]
+    """Tokenized input sequence as strings."""
+
     distance: float = 0.0
     """
     The distance from the neighbouring latent.
@@ -125,7 +128,7 @@ class LatentRecord:
     """The latent associated with the record."""
 
     examples: list[ActivatingExample] = field(default_factory=list)
-    """Example sequences where the latent activations, assumed to be sorted in
+    """Example sequences where the latent activates, assumed to be sorted in
     descending order by max activation."""
 
     not_active: list[NonActivatingExample] = field(default_factory=list)
```
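The net effect of this diff is that `str_tokens` and `normalized_activations` move off the `Example` base class: `ActivatingExample` keeps both as optionals, while `NonActivatingExample` requires `str_tokens`. A condensed sketch of the resulting hierarchy, with the jaxtyping annotations simplified to plain `torch.Tensor`:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Example:
    tokens: torch.Tensor
    activations: torch.Tensor


@dataclass
class ActivatingExample(Example):
    # Moved down from Example; optional here, matching the diff above.
    normalized_activations: Optional[torch.Tensor] = None
    str_tokens: Optional[list[str]] = None
    quantile: int = 0


@dataclass
class NonActivatingExample(Example):
    # Required: non-activating examples are decoded at construction time
    # (see the constructors.py diff above).
    str_tokens: list[str]
    distance: float = 0.0
```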

delphi/scorers/embedding/embedding.py

Lines changed: 38 additions & 40 deletions

```diff
@@ -1,9 +1,9 @@
 import asyncio
 import random
 from dataclasses import dataclass
-from typing import NamedTuple
+from typing import NamedTuple, Sequence
 
-from transformers import PreTrainedTokenizer
+from delphi.latents.latents import ActivatingExample, NonActivatingExample
 
 from ...latents import Example, LatentRecord
 from ..scorer import Scorer, ScorerResult
@@ -33,56 +33,53 @@ class EmbeddingScorer(Scorer):
     def __init__(
         self,
         model,
-        tokenizer: PreTrainedTokenizer | None = None,
         verbose: bool = False,
         **generation_kwargs,
     ):
         self.model = model
         self.verbose = verbose
-        self.tokenizer = tokenizer
         self.generation_kwargs = generation_kwargs
 
-    async def __call__(  # type: ignore
-        self,  # type: ignore
-        record: LatentRecord,  # type: ignore
-    ) -> ScorerResult:  # type: ignore
+    async def __call__(
+        self,
+        record: LatentRecord,
+    ) -> ScorerResult:
         samples = self._prepare(record)
 
         random.shuffle(samples)
         results = self._query(
             record.explanation,
-            samples,  # type: ignore
+            samples,
         )
 
         return ScorerResult(record=record, score=results)
 
-    def call_sync(self, record: LatentRecord) -> list[EmbeddingOutput]:
-        return asyncio.run(self.__call__(record))  # type: ignore
+    def call_sync(self, record: LatentRecord) -> ScorerResult:
+        return asyncio.run(self.__call__(record))
 
-    def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
+    def _prepare(self, record: LatentRecord) -> list[Sample]:
         """
         Prepare and shuffle a list of samples for classification.
         """
+        samples = []
+
+        assert (
+            record.extra_examples is not None
+        ), "Extra (non-activating) examples need to be provided"
 
-        defaults = {
-            "tokenizer": self.tokenizer,
-        }
-        samples = examples_to_samples(
-            record.extra_examples,  # type: ignore
-            distance=-1,
-            **defaults,  # type: ignore
+        samples.extend(
+            examples_to_samples(
+                record.extra_examples,
+            )
         )
 
-        for i, examples in enumerate(record.test):
-            samples.extend(
-                examples_to_samples(
-                    examples,  # type: ignore
-                    distance=i + 1,
-                    **defaults,  # type: ignore
-                )
+        samples.extend(
+            examples_to_samples(
+                record.test,
             )
+        )
 
-        return samples  # type: ignore
+        return samples
 
     def _query(self, explanation: str, samples: list[Sample]) -> list[EmbeddingOutput]:
         explanation_string = (
@@ -93,38 +90,39 @@ def _query(self, explanation: str, samples: list[Sample]) -> list[EmbeddingOutpu
         query_embeding = self.model.encode(explanation_prompt)
         samples_text = [sample.text for sample in samples]
 
-        # # Temporary batching
-        # sample_embedings = []
-        # for i in range(0, len(samples_text), 10):
-        #     sample_embedings.extend(self.model.encode(samples_text[i:i+10]))
         sample_embedings = self.model.encode(samples_text)
         similarity = self.model.similarity(query_embeding, sample_embedings)[0]
 
         results = []
         for i in range(len(samples)):
-            # print(i)
             samples[i].data.similarity = similarity[i].item()
             results.append(samples[i].data)
         return results
 
 
 def examples_to_samples(
-    examples: list[Example],
-    tokenizer: PreTrainedTokenizer,
-    **sample_kwargs,
+    examples: Sequence[Example],
 ) -> list[Sample]:
     samples = []
     for example in examples:
-        if tokenizer is not None:
-            text = "".join(tokenizer.batch_decode(example.tokens))
-        else:
-            text = "".join(example.tokens)
+        assert isinstance(example, ActivatingExample) or isinstance(
+            example, NonActivatingExample
+        )
+        assert example.str_tokens is not None
+        text = "".join(str(token) for token in example.str_tokens)
         activations = example.activations.tolist()
         samples.append(
             Sample(
                 text=text,
                 activations=activations,
-                data=EmbeddingOutput(text=text, **sample_kwargs),
+                data=EmbeddingOutput(
+                    text=text,
+                    distance=(
+                        example.quantile
+                        if isinstance(example, ActivatingExample)
+                        else example.distance
+                    ),
+                ),
             )
         )
```

delphi/scorers/surprisal/surprisal.py

Lines changed: 29 additions & 31 deletions

```diff
@@ -1,15 +1,17 @@
 import random
 from dataclasses import dataclass
-from typing import NamedTuple
+from typing import NamedTuple, Sequence
 
 import torch
 from simple_parsing import field
 from torch.nn.functional import cross_entropy
-from transformers import PreTrainedTokenizer
 
-from delphi.utils import assert_type
-
-from ...latents import ActivatingExample, Example, LatentRecord
+from ...latents import (
+    ActivatingExample,
+    Example,
+    LatentRecord,
+    NonActivatingExample,
+)
 from ..scorer import Scorer, ScorerResult
 from .prompts import BASEPROMPT as base_prompt
 
@@ -44,21 +46,19 @@ class SurprisalScorer(Scorer):
     def __init__(
         self,
         model,
-        tokenizer,
         verbose: bool,
         batch_size: int,
         **generation_kwargs,
     ):
         self.model = model
         self.verbose = verbose
-        self.tokenizer = tokenizer
         self.batch_size = batch_size
         self.generation_kwargs = generation_kwargs
 
-    async def __call__(  # type: ignore
-        self,  # type: ignore
-        record: LatentRecord,  # type: ignore
-    ) -> ScorerResult:  # type: ignore
+    async def __call__(
+        self,
+        record: LatentRecord,
+    ) -> ScorerResult:
         samples = self._prepare(record)
 
         random.shuffle(samples)
@@ -74,35 +74,25 @@ def _prepare(self, record: LatentRecord) -> list[Sample]:
         Prepare and shuffle a list of samples for classification.
         """
 
-        defaults = {
-            "tokenizer": self.tokenizer,
-        }
-
         assert record.extra_examples is not None, "No extra examples provided"
         samples = examples_to_samples(
             record.extra_examples,
-            distance=-1,
-            **defaults,
         )
 
-        for i, examples in enumerate(record.test):
-            examples = assert_type(list, examples)
-            samples.extend(
-                examples_to_samples(
-                    examples,
-                    distance=i + 1,
-                    **defaults,
-                )
+        samples.extend(
+            examples_to_samples(
+                record.test,
             )
+        )
 
         return samples
 
     def compute_loss_with_kv_cache(
         self, explanation: str, samples: list[Sample], batch_size=2
     ):
-        # print(explanation_prompt)
         model = self.model
         tokenizer = self.model.tokenizer
+        assert tokenizer is not None, "Tokenizer is not set in model.tokenizer"
         # Tokenize explanation
         tokenizer.padding_side = "right"
         tokenizer.pad_token = tokenizer.eos_token
@@ -187,20 +177,28 @@ def _query(self, explanation: str, samples: list[Sample]) -> list[SurprisalOutpu
 
 
 def examples_to_samples(
-    examples: list[Example] | list[ActivatingExample],
-    tokenizer: PreTrainedTokenizer,
-    **sample_kwargs,
+    examples: Sequence[Example],
 ) -> list[Sample]:
     samples = []
     for example in examples:
-        text = "".join(tokenizer.batch_decode(example.tokens))
+        assert isinstance(example, ActivatingExample) or isinstance(
+            example, NonActivatingExample
+        )
+        assert example.str_tokens is not None
+        text = "".join(str(token) for token in example.str_tokens)
         activations = example.activations.tolist()
         samples.append(
             Sample(
                 text=text,
                 activations=activations,
                 data=SurprisalOutput(
-                    activations=activations, text=text, **sample_kwargs
+                    activations=activations,
+                    text=text,
+                    distance=(
+                        example.quantile
+                        if isinstance(example, ActivatingExample)
+                        else example.distance
+                    ),
                 ),
             )
         )
```
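Both scorers now derive `distance` from the example itself instead of threading it through `**sample_kwargs`: the quantile for activating examples, the stored distance for non-activating ones. A small usage sketch of that conditional, assuming the `delphi` package is installed and using the field names shown in the diffs above:

```python
import torch

from delphi.latents import ActivatingExample, NonActivatingExample


def example_distance(example) -> float:
    # Mirrors the conditional inside examples_to_samples.
    return (
        example.quantile
        if isinstance(example, ActivatingExample)
        else example.distance
    )


toks, acts = torch.tensor([1, 2]), torch.rand(2)
act = ActivatingExample(tokens=toks, activations=acts, quantile=3)
non = NonActivatingExample(
    tokens=toks, activations=acts, str_tokens=["a", "b"], distance=-1.0
)

assert example_distance(act) == 3
assert example_distance(non) == -1.0
```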
