11from dataclasses import dataclass , field
2- from typing import NamedTuple , Optional
2+ from typing import Literal , NamedTuple , Optional
33
44import blobfile as bf
55import orjson
6+ import torch
67from jaxtyping import Float , Int
78from torch import Tensor
89from transformers import PreTrainedTokenizer , PreTrainedTokenizerFast
@@ -203,6 +204,8 @@ def display(
203204 tokenizer : PreTrainedTokenizer | PreTrainedTokenizerFast ,
204205 threshold : float = 0.0 ,
205206 n : int = 10 ,
207+ do_display : bool = True ,
208+ example_source : Literal ["examples" , "train" , "test" ] = "examples" ,
206209 ):
207210 """
208211 Display the latent record in a formatted string.
@@ -216,9 +219,8 @@ def display(
216219 Returns:
217220 str: The formatted string.
218221 """
219- from IPython .core .display import HTML , display # type: ignore
220222
221- def _to_string (tokens : list [ str ] , activations : Float [Tensor , "ctx_len" ]) -> str :
223+ def _to_string (toks , activations : Float [Tensor , "ctx_len" ]) -> str :
222224 """
223225 Convert tokens and activations to a string.
224226
@@ -229,28 +231,145 @@ def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
229231 Returns:
230232 str: The formatted string.
231233 """
232- result = []
233- i = 0
234-
235- max_act = activations .max ()
236- _threshold = max_act * threshold
237-
238- while i < len (tokens ):
239- if activations [i ] > _threshold :
240- result .append ("<mark>" )
241- while i < len (tokens ) and activations [i ] > _threshold :
242- result .append (tokens [i ])
243- i += 1
244- result .append ("</mark>" )
245- else :
246- result .append (tokens [i ])
247- i += 1
248- return "" .join (result )
249- return ""
250-
251- strings = [
252- _to_string (tokenizer .batch_decode (example .tokens ), example .activations )
253- for example in self .examples [:n ]
254- ]
255-
256- display (HTML ("<br><br>" .join (strings )))
234+ text_spacing = "0.00em"
235+ toks = convert_token_array_to_list (toks )
236+ activations = convert_token_array_to_list (activations )
237+ inverse_vocab = {v : k for k , v in tokenizer .vocab .items ()}
238+ toks = [
239+ [
240+ inverse_vocab [int (t )]
241+ .replace ("Ġ" , " " )
242+ .replace ("▁" , " " )
243+ .replace ("\n " , "\\ n" )
244+ for t in tok
245+ ]
246+ for tok in toks
247+ ]
248+ highlighted_text = []
249+ highlighted_text .append (
250+ """
251+ <body style="background-color: black; color: white;">
252+ """
253+ )
254+ max_value = max ([max (activ ) for activ in activations ])
255+ min_value = min ([min (activ ) for activ in activations ])
256+ # Add color bar
257+ highlighted_text .append (
258+ "Token Activations: " + make_colorbar (min_value , max_value )
259+ )
260+
261+ highlighted_text .append ('<div style="margin-top: 0.5em;"></div>' )
262+ for seq_ind , (act , tok ) in enumerate (zip (activations , toks )):
263+ for act_ind , (a , t ) in enumerate (zip (act , tok )):
264+ text_color , background_color = value_to_color (
265+ a , max_value , min_value
266+ )
267+ highlighted_text .append (
268+ f'<span style="background-color:{ background_color } ;'
269+ f'margin-right: { text_spacing } ; color:rgb({ text_color } )"'
270+ f">{ escape (t )} </span>"
271+ ) # noqa: E501
272+ highlighted_text .append ('<div style="margin-top: 0.2em;"></div>' )
273+ highlighted_text = "" .join (highlighted_text )
274+ return highlighted_text
275+
276+ match example_source :
277+ case "examples" :
278+ examples = self .examples
279+ case "train" :
280+ examples = self .train
281+ case "test" :
282+ examples = [x [0 ] for x in self .test ]
283+ case _:
284+ raise ValueError (f"Unknown example source: { example_source } " )
285+ examples = examples [:n ]
286+ strings = _to_string (
287+ [example .tokens for example in examples ],
288+ [example .activations for example in examples ],
289+ )
290+
291+ if do_display :
292+ from IPython .display import HTML , display
293+
294+ display (HTML (strings ))
295+ else :
296+ return strings
297+
298+
def make_colorbar(
    min_value,
    max_value,
    white=255,
    red_blue_ness=250,
    positive_threshold=0.01,
    negative_threshold=0.01,
):
    """Build an HTML legend mapping activation values to colors.

    Negative activations are shaded red, positive activations blue, and
    zero is shown on a white chip.  Four chips are drawn per side, each
    labelled with the activation value at that fraction of min/max.

    Args:
        min_value: Smallest activation in the data (may be >= 0).
        max_value: Largest activation in the data (may be <= 0).
        white: RGB channel value for the zero chip's background.
        red_blue_ness: Base intensity for the red/blue shading.
        positive_threshold: No positive chips unless max_value exceeds this.
        negative_threshold: No negative chips unless min_value is below -this.

    Returns:
        str: Concatenated HTML ``<span>`` chips.
    """
    colorbar = ""
    num_colors = 4
    # Negative side: red chips, most negative (most saturated) first.
    if min_value < -negative_threshold:
        for i in range(num_colors, 0, -1):
            ratio = i / (num_colors)
            value = round((min_value * ratio), 1)
            # Dark backgrounds get white text for contrast.
            text_color = "255,255,255" if ratio > 0.5 else "0,0,0"
            colorbar += f'<span style="background-color:rgba(255, {int(red_blue_ness - (red_blue_ness * ratio))},{int(red_blue_ness - (red_blue_ness * ratio))},1); color:rgb({text_color})">&nbsp;{value}&nbsp;</span>'  # noqa: E501
    # Zero chip (always present).
    colorbar += f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">&nbsp;0.0&nbsp;</span>'  # noqa: E501
    # Positive side: blue chips, least saturated first.
    if max_value > positive_threshold:
        for i in range(1, num_colors + 1):
            ratio = i / (num_colors)
            value = round((max_value * ratio), 1)
            text_color = "255,255,255" if ratio > 0.5 else "0,0,0"
            colorbar += f'<span style="background-color:rgba({int(red_blue_ness - (red_blue_ness * ratio))},{int(red_blue_ness - (red_blue_ness * ratio))},255,1);color:rgb({text_color})">&nbsp;{value}&nbsp;</span>'  # noqa: E501
    return colorbar
326+
327+
def value_to_color(
    activation,
    max_value,
    min_value,
    white=255,
    red_blue_ness=250,
    positive_threshold=0.01,
    negative_threshold=0.01,
):
    """Map one activation to a ``(text_color, background_color)`` CSS pair.

    Positive activations shade toward blue, negative toward red, and
    near-zero values (within the thresholds) sit on a white background.
    The text color flips to white once the background is dark enough.

    Args:
        activation: The activation value to color.
        max_value: Largest activation (normalizes the positive side).
        min_value: Smallest activation (normalizes the negative side).
        white: RGB channel value for the neutral background.
        red_blue_ness: Base intensity for the red/blue shading.
        positive_threshold: Values above this count as positive.
        negative_threshold: Values below -this count as negative.

    Returns:
        tuple[str, str]: ``(text rgb triple, background rgba() string)``.
    """
    if activation > positive_threshold:
        # Positive: fade the red/green channels so stronger values are bluer.
        strength = activation / max_value
        shade = int(red_blue_ness - (red_blue_ness * strength))
        foreground = "255,255,255" if strength > 0.5 else "0,0,0"
        background = f"rgba({shade},{shade},255,1)"
    elif activation < -negative_threshold:
        # Negative: fade the green/blue channels so stronger values are redder.
        strength = activation / min_value
        shade = int(red_blue_ness - (red_blue_ness * strength))
        foreground = "255,255,255" if strength > 0.5 else "0,0,0"
        background = f"rgba(255, {shade},{shade},1)"
    else:
        # Near zero: black on white.
        foreground = "0,0,0"
        background = f"rgba({white},{white},{white},1)"
    return foreground, background
349+
350+
def convert_token_array_to_list(array):
    """Normalize a token/activation container into a list of lists.

    Accepts a 1-D or 2-D tensor, a flat list of numbers, a list of lists,
    or a list of 1-D tensors, and always returns ``list[list]`` so callers
    can iterate sequences uniformly.

    Fixes over the previous version: an empty list no longer raises
    ``IndexError`` (it is returned as-is), and a flat list of *floats*
    (e.g. activations) is wrapped just like a flat list of ints.

    Args:
        array: ``torch.Tensor`` (1-D or 2-D), flat list of numbers,
            list of lists, or list of 1-D tensors.

    Returns:
        list[list]: The normalized nested list.

    Raises:
        NotImplementedError: If a tensor with dim > 2 is passed.
    """
    if isinstance(array, torch.Tensor):
        if array.dim() == 1:
            array = [array.tolist()]
        elif array.dim() == 2:
            array = array.tolist()
        else:
            raise NotImplementedError("tokens must be 1 or 2 dimensional")
    elif isinstance(array, list):
        # Nothing to inspect in an empty list; it is already "a list of lists".
        if not array:
            return array
        # Wrap a flat list of numbers (token ids or float activations).
        if isinstance(array[0], (int, float)):
            array = [array]
        if isinstance(array[0], torch.Tensor):
            array = [t.tolist() for t in array]
    return array
366+
367+
def escape(t):
    """HTML-escape a decoded token for embedding in a ``<span>``.

    Spaces become ``&nbsp;`` so whitespace stays visible in rendered HTML,
    the ``<bos>`` marker is replaced with a readable label, and remaining
    angle brackets are escaped so token text cannot inject markup.

    NOTE(review): the rendered source contained no-op replaces
    (``.replace("<", "<")``) — the HTML entities were clearly lost when the
    file was extracted; they are reconstructed here.

    Args:
        t: A decoded token string.

    Returns:
        str: The HTML-safe token text.
    """
    # Space substitution first: "<bos>" contains no spaces, and the entity
    # replacements below introduce no characters the later steps would touch.
    t = (
        t.replace(" ", "&nbsp;")
        .replace("<bos>", "BOS")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )
    return t
0 commit comments