@@ -1,3 +1,4 @@
+import shutil
 from pathlib import Path
 from typing import cast
 
@@ -79,13 +80,17 @@ def cache_setup(tmp_path_factory, mock_dataset: torch.Tensor, model: PreTrainedM
     hookpoint_to_sparse_encode, _ = load_hooks_sparse_coders(model, run_cfg_gemma)
     # Define cache config and initialize cache
     log_path = Path.cwd() / "results" / "test" / "log"
+    shutil.rmtree(log_path, ignore_errors=True)
     log_path.mkdir(parents=True, exist_ok=True)
 
-    cache = LatentCache(
-        model,
-        hookpoint_to_sparse_encode,
-        batch_size=cache_cfg.batch_size,
-        log_path=log_path,
+    cache, empty_cache = (
+        LatentCache(
+            model,
+            hookpoint_to_sparse_encode,
+            batch_size=cache_cfg.batch_size,
+            log_path=log_path,
+        )
+        for _ in range(2)
     )
 
     # Generate mock tokens and run the cache
@@ -104,60 +109,9 @@ def cache_setup(tmp_path_factory, mock_dataset: torch.Tensor, model: PreTrainedM
     )
     return {
         "cache": cache,
+        "empty_cache": empty_cache,
         "tokens": tokens,
         "cache_cfg": cache_cfg,
         "temp_dir": temp_dir,
         "firing_counts": hookpoint_firing_counts,
     }
-
-
-def test_hookpoint_firing_counts_initialization(cache_setup):
-    """
-    Ensure that hookpoint_firing_counts is initialized as an empty dictionary.
-    """
-    cache = cache_setup["cache"]
-    assert isinstance(cache.hookpoint_firing_counts, dict)
-    assert len(cache.hookpoint_firing_counts) == 0  # Should be empty before run()
-
-
-def test_hookpoint_firing_counts_updates(cache_setup):
-    """
-    Ensure that hookpoint_firing_counts is properly updated after running the cache.
-    """
-    cache = cache_setup["cache"]
-    tokens = cache_setup["tokens"]
-    cache.run(cache_setup["cache_cfg"].n_tokens, tokens)
-
-    assert (
-        len(cache.hookpoint_firing_counts) > 0
-    ), "hookpoint_firing_counts should not be empty after run()"
-    for hookpoint, counts in cache.hookpoint_firing_counts.items():
-        assert isinstance(
-            counts, torch.Tensor
-        ), f"Counts for {hookpoint} should be a torch.Tensor"
-        assert counts.ndim == 1, f"Counts for {hookpoint} should be a 1D tensor"
-        assert (counts >= 0).all(), f"Counts for {hookpoint} should be non-negative"
-
-
-def test_hookpoint_firing_counts_persistence(cache_setup):
-    """
-    Ensure that hookpoint_firing_counts are correctly saved and loaded.
-    """
-    cache = cache_setup["cache"]
-    cache.save_firing_counts()
-
-    firing_counts_path = Path.cwd() / "results" / "log" / "hookpoint_firing_counts.pt"
-    assert firing_counts_path.exists(), "Firing counts file should exist after saving"
-
-    loaded_counts = torch.load(firing_counts_path, weights_only=True)
-    assert isinstance(
-        loaded_counts, dict
-    ), "Loaded firing counts should be a dictionary"
-    assert (
-        loaded_counts.keys() == cache.hookpoint_firing_counts.keys()
-    ), "Loaded firing counts keys should match saved keys"
-
-    for hookpoint, counts in loaded_counts.items():
-        assert torch.equal(
-            counts, cache.hookpoint_firing_counts[hookpoint]
-        ), f"Mismatch in firing counts for {hookpoint}"
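Note on the initialization idiom introduced in the fixture above: unpacking a two-element generator expression runs the constructor twice, so `cache` and `empty_cache` end up as two independent `LatentCache` instances rather than two names for the same object. A minimal sketch of the pattern, using a hypothetical stand-in class since `LatentCache` itself is defined outside this diff:

# StubCache is a hypothetical stand-in for LatentCache; it only
# illustrates the unpacking idiom, not the real constructor.
class StubCache:
    def __init__(self, batch_size: int, log_path: str):
        self.batch_size = batch_size
        self.log_path = log_path
        self.hookpoint_firing_counts: dict = {}

# Tuple unpacking consumes the generator, evaluating StubCache(...) twice,
# so the two names are bound to distinct objects with distinct state.
cache, empty_cache = (
    StubCache(batch_size=32, log_path="results/test/log") for _ in range(2)
)

assert cache is not empty_cache
assert cache.hookpoint_firing_counts is not empty_cache.hookpoint_firing_counts

The same effect could be written as two explicit constructor calls; the generator form simply avoids duplicating the argument list.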