Skip to content

Commit ac5f642

Browse files
authored
Add functionality to .display() (#140)
feat : better .display() function in latent record
1 parent b1b5e1b commit ac5f642

File tree

1 file changed

+147
-28
lines changed

1 file changed

+147
-28
lines changed

delphi/latents/latents.py

Lines changed: 147 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from dataclasses import dataclass, field
2-
from typing import NamedTuple, Optional
2+
from typing import Literal, NamedTuple, Optional
33

44
import blobfile as bf
55
import orjson
6+
import torch
67
from jaxtyping import Float, Int
78
from torch import Tensor
89
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -203,6 +204,8 @@ def display(
203204
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
204205
threshold: float = 0.0,
205206
n: int = 10,
207+
do_display: bool = True,
208+
example_source: Literal["examples", "train", "test"] = "examples",
206209
):
207210
"""
208211
Display the latent record in a formatted string.
@@ -216,9 +219,8 @@ def display(
216219
Returns:
217220
str: The formatted string.
218221
"""
219-
from IPython.core.display import HTML, display # type: ignore
220222

221-
def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
223+
def _to_string(toks, activations: Float[Tensor, "ctx_len"]) -> str:
222224
"""
223225
Convert tokens and activations to a string.
224226
@@ -229,28 +231,145 @@ def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
229231
Returns:
230232
str: The formatted string.
231233
"""
232-
result = []
233-
i = 0
234-
235-
max_act = activations.max()
236-
_threshold = max_act * threshold
237-
238-
while i < len(tokens):
239-
if activations[i] > _threshold:
240-
result.append("<mark>")
241-
while i < len(tokens) and activations[i] > _threshold:
242-
result.append(tokens[i])
243-
i += 1
244-
result.append("</mark>")
245-
else:
246-
result.append(tokens[i])
247-
i += 1
248-
return "".join(result)
249-
return ""
250-
251-
strings = [
252-
_to_string(tokenizer.batch_decode(example.tokens), example.activations)
253-
for example in self.examples[:n]
254-
]
255-
256-
display(HTML("<br><br>".join(strings)))
234+
text_spacing = "0.00em"
235+
toks = convert_token_array_to_list(toks)
236+
activations = convert_token_array_to_list(activations)
237+
inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
238+
toks = [
239+
[
240+
inverse_vocab[int(t)]
241+
.replace("Ġ", " ")
242+
.replace("▁", " ")
243+
.replace("\n", "\\n")
244+
for t in tok
245+
]
246+
for tok in toks
247+
]
248+
highlighted_text = []
249+
highlighted_text.append(
250+
"""
251+
<body style="background-color: black; color: white;">
252+
"""
253+
)
254+
max_value = max([max(activ) for activ in activations])
255+
min_value = min([min(activ) for activ in activations])
256+
# Add color bar
257+
highlighted_text.append(
258+
"Token Activations: " + make_colorbar(min_value, max_value)
259+
)
260+
261+
highlighted_text.append('<div style="margin-top: 0.5em;"></div>')
262+
for seq_ind, (act, tok) in enumerate(zip(activations, toks)):
263+
for act_ind, (a, t) in enumerate(zip(act, tok)):
264+
text_color, background_color = value_to_color(
265+
a, max_value, min_value
266+
)
267+
highlighted_text.append(
268+
f'<span style="background-color:{background_color};'
269+
f'margin-right: {text_spacing}; color:rgb({text_color})"'
270+
f">{escape(t)}</span>"
271+
) # noqa: E501
272+
highlighted_text.append('<div style="margin-top: 0.2em;"></div>')
273+
highlighted_text = "".join(highlighted_text)
274+
return highlighted_text
275+
276+
match example_source:
277+
case "examples":
278+
examples = self.examples
279+
case "train":
280+
examples = self.train
281+
case "test":
282+
examples = [x[0] for x in self.test]
283+
case _:
284+
raise ValueError(f"Unknown example source: {example_source}")
285+
examples = examples[:n]
286+
strings = _to_string(
287+
[example.tokens for example in examples],
288+
[example.activations for example in examples],
289+
)
290+
291+
if do_display:
292+
from IPython.display import HTML, display
293+
294+
display(HTML(strings))
295+
else:
296+
return strings
297+
298+
299+
def make_colorbar(
    min_value,
    max_value,
    white=255,
    red_blue_ness=250,
    positive_threshold=0.01,
    negative_threshold=0.01,
):
    """Build an HTML legend of activation-value swatches.

    Negative values render on red backgrounds (most intense first), zero on
    white, and positive values on blue (least intense first), with four
    graduated steps on each side where that side is clearly nonzero.
    """
    steps = 4
    pieces = []
    # Negative side: only when min_value is clearly below zero.
    if min_value < -negative_threshold:
        for step in range(steps, 0, -1):
            frac = step / steps
            label = round(min_value * frac, 1)
            fg = "255,255,255" if frac > 0.5 else "0,0,0"
            channel = int(red_blue_ness - (red_blue_ness * frac))
            pieces.append(
                f'<span style="background-color:rgba(255, {channel},{channel},1); color:rgb({fg})">&nbsp{label}&nbsp</span>'  # noqa: E501
            )
    # Zero swatch is always present.
    pieces.append(
        f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">&nbsp0.0&nbsp</span>'  # noqa: E501
    )
    # Positive side: only when max_value is clearly above zero.
    if max_value > positive_threshold:
        for step in range(1, steps + 1):
            frac = step / steps
            label = round(max_value * frac, 1)
            fg = "255,255,255" if frac > 0.5 else "0,0,0"
            channel = int(red_blue_ness - (red_blue_ness * frac))
            pieces.append(
                f'<span style="background-color:rgba({channel},{channel},255,1);color:rgb({fg})">&nbsp{label}&nbsp</span>'  # noqa: E501
            )
    return "".join(pieces)
326+
327+
328+
def value_to_color(
    activation,
    max_value,
    min_value,
    white=255,
    red_blue_ness=250,
    positive_threshold=0.01,
    negative_threshold=0.01,
):
    """Map one activation value to CSS ``(text_color, background_color)``.

    Positive activations get a blue-tinted background whose intensity scales
    with ``activation / max_value``; negative ones get red scaled by
    ``activation / min_value``; near-zero values get a plain white background.
    Text flips to white once the background is more than half saturated.
    """
    if activation > positive_threshold:
        frac = activation / max_value
        channel = int(red_blue_ness - (red_blue_ness * frac))
        fg = "255,255,255" if frac > 0.5 else "0,0,0"
        bg = f"rgba({channel},{channel},255,1)"  # noqa: E501
    elif activation < -negative_threshold:
        frac = activation / min_value
        channel = int(red_blue_ness - (red_blue_ness * frac))
        fg = "255,255,255" if frac > 0.5 else "0,0,0"
        bg = f"rgba(255, {channel},{channel},1)"  # noqa: E501
    else:
        fg = "0,0,0"
        bg = f"rgba({white},{white},{white},1)"
    return fg, bg
349+
350+
351+
def convert_token_array_to_list(array):
    """Normalize token/activation containers to a list of lists.

    Accepts a 1-D or 2-D ``torch.Tensor``, a flat list of ints, a list of
    1-D tensors, or an already-nested list of lists, and returns a
    list-of-lists representation in every case.

    Args:
        array: Tensor or (possibly nested) list of tokens/activations.

    Returns:
        list: The input as a list of per-sequence lists.

    Raises:
        NotImplementedError: If a tensor that is not 1- or 2-dimensional
            is given.
    """
    if isinstance(array, torch.Tensor):
        if array.dim() == 1:
            # Single sequence -> wrap so callers can always iterate rows.
            array = [array.tolist()]
        elif array.dim() == 2:
            array = array.tolist()
        else:
            raise NotImplementedError("tokens must be 1 or 2 dimensional")
    elif isinstance(array, list):
        # An empty batch has no first element to inspect; it is already a
        # (zero-row) list of lists. The original crashed here with IndexError.
        if not array:
            return array
        # ensure it's a list of lists
        if isinstance(array[0], int):
            array = [array]
        if isinstance(array[0], torch.Tensor):
            array = [t.tolist() for t in array]
    return array
366+
367+
368+
def escape(t):
    """Return an HTML-ready rendering of a token string.

    Spaces become ``&nbsp;`` so runs of whitespace survive HTML rendering,
    the literal ``<bos>`` marker is rewritten as plain ``BOS``, and any
    remaining angle brackets are entity-escaped. Order matters: ``<bos>``
    must be handled before ``<`` / ``>`` are escaped.
    """
    replacements = (
        (" ", "&nbsp;"),
        ("<bos>", "BOS"),
        ("<", "&lt;"),
        (">", "&gt;"),
    )
    for old, new in replacements:
        t = t.replace(old, new)
    return t

0 commit comments

Comments
 (0)