Skip to content

Commit 431d654

Browse files
anthonyduong9SrGonaopre-commit-ci[bot]
authored
feat: replaces print with logging (#136)
* feat: replaces print with logging * Change some infos to warnings in constructor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Goncalo Paulo <30472805+SrGonao@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 66fa7c4 commit 431d654

File tree

17 files changed

+82
-65
lines changed

17 files changed

+82
-65
lines changed

__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# This logger is needed for running tests from the repo directory.
2+
# The actual package logger is in delphi/delphi/__init__.py
3+
import logging
4+
5+
logger = logging.getLogger(__name__)

delphi/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
11
__version__ = "0.0.2"
2+
3+
import logging
4+
5+
logger = logging.getLogger(__name__)

delphi/__main__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import logging
23
import os
34
from functools import partial
45
from pathlib import Path
@@ -17,6 +18,7 @@
1718
PreTrainedTokenizerFast,
1819
)
1920

21+
from delphi import logger
2022
from delphi.clients import Offline, OpenRouter
2123
from delphi.config import RunConfig
2224
from delphi.explainers import ContrastiveExplainer, DefaultExplainer, NoOpExplainer
@@ -450,6 +452,16 @@ async def run(
450452

451453

452454
if __name__ == "__main__":
455+
# Configure logging for CLI usage
456+
logger.setLevel(logging.INFO)
457+
file_handler = logging.FileHandler("delphi.log")
458+
file_handler.setLevel(logging.INFO)
459+
formatter = logging.Formatter(
460+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
461+
)
462+
file_handler.setFormatter(formatter)
463+
logger.addHandler(file_handler)
464+
453465
parser = ArgumentParser()
454466
parser.add_arguments(RunConfig, dest="run_cfg")
455467
args = parser.parse_args()

delphi/clients/offline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
destroy_model_parallel,
1313
)
1414

15-
from ..logger import logger
15+
from delphi import logger
16+
1617
from .client import Client, Response
1718

1819

delphi/clients/openrouter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
import httpx
55

6-
from ..logger import logger
6+
from delphi import logger
7+
78
from .client import Client, Response
89
from .types import ChatFormatRequest
910

@@ -30,7 +31,7 @@ def __init__(
3031
self.temperature = temperature
3132
timeout_config = httpx.Timeout(5.0)
3233
self.client = httpx.AsyncClient(timeout=timeout_config)
33-
print("WARNING: We currently don't support logprobs for OpenRouter")
34+
logger.warning("We currently don't support logprobs for OpenRouter")
3435

3536
def postprocess(self, response):
3637
response_json = response.json()

delphi/explainers/contrastive_explainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,15 @@ async def __call__(self, record: LatentRecord) -> ExplainerResult:
6060
response_text = response
6161
explanation = self.parse_explanation(response_text)
6262
if self.verbose:
63-
from ..logger import logger
63+
from delphi import logger
6464

6565
logger.info(f"Explanation: {explanation}")
6666
logger.info(f"Messages: {messages[-1]['content']}")
6767
logger.info(f"Response: {response}")
6868

6969
return ExplainerResult(record=record, explanation=explanation)
7070
except Exception as e:
71-
from ..logger import logger
71+
from delphi import logger
7272

7373
logger.error(f"Explanation parsing failed: {repr(e)}")
7474
return ExplainerResult(

delphi/explainers/explainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88

99
import aiofiles
1010

11+
from delphi import logger
12+
1113
from ..clients.client import Client, Response
1214
from ..latents.latents import ActivatingExample, LatentRecord
13-
from ..logger import logger
1415

1516

1617
class ExplainerResult(NamedTuple):
@@ -127,7 +128,7 @@ async def explanation_loader(
127128
explanation = json.loads(await f.read())
128129
return ExplainerResult(record=record, explanation=explanation)
129130
except FileNotFoundError:
130-
print(f"No explanation found for {record.latent}")
131+
logger.info(f"No explanation found for {record.latent}")
131132
return ExplainerResult(record=record, explanation="No explanation found")
132133

133134

delphi/latents/cache.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from tqdm import tqdm
1313
from transformers import PreTrainedModel
1414

15+
from delphi import logger
1516
from delphi.config import CacheConfig
1617
from delphi.latents.collect_activations import collect_activations
1718

@@ -298,7 +299,7 @@ def run(self, n_tokens: int, tokens: token_tensor_type):
298299
pbar.update(1)
299300
pbar.set_postfix({"Total Tokens": f"{total_tokens:,}"})
300301

301-
print(f"Total tokens processed: {total_tokens:,}")
302+
logger.info(f"Total tokens processed: {total_tokens:,}")
302303
self.cache.save()
303304
self.save_firing_counts()
304305

@@ -374,8 +375,8 @@ def save_splits(self, n_splits: int, save_dir: Path, save_tokens: bool = True):
374375
masked_locations = masked_locations.astype(np.uint16)
375376
else:
376377
masked_locations = masked_locations.astype(np.uint32)
377-
print(
378-
"Warning: Increasing the number of splits might reduce the"
378+
logger.warning(
379+
"Increasing the number of splits might reduce the"
379380
"memory usage of the cache."
380381
)
381382

@@ -399,10 +400,10 @@ def generate_statistics_cache(self):
399400
to the console.
400401
"""
401402
assert self.width is not None, "Width must be set before generating statistics"
402-
print("Feature statistics:")
403+
logger.info("Feature statistics:")
403404
# Token frequency
404405
for module_path in self.cache.latent_locations.keys():
405-
print(f"# Module: {module_path}")
406+
logger.info(f"# Module: {module_path}")
406407
generate_statistics_cache(
407408
self.cache.tokens[module_path],
408409
self.cache.latent_locations[module_path],
@@ -493,7 +494,7 @@ def generate_statistics_cache(
493494
num_alive = counts.shape[0]
494495
fraction_alive = num_alive / width
495496
if verbose:
496-
print(f"Fraction of latents alive: {fraction_alive:%}")
497+
logger.info(f"Fraction of latents alive: {fraction_alive:%}")
497498
# Compute densities of latents
498499
densities = counts / total_n_tokens
499500

@@ -502,8 +503,12 @@ def generate_statistics_cache(
502503
# How many fired more than 10% of the time
503504
ten_percent = (densities > 0.1).sum() / width
504505
if verbose:
505-
print(f"Fraction of latents fired more than 1% of the time: {one_percent:%}")
506-
print(f"Fraction of latents fired more than 10% of the time: {ten_percent:%}")
506+
logger.info(
507+
f"Fraction of latents fired more than 1% of the time: {one_percent:%}"
508+
)
509+
logger.info(
510+
f"Fraction of latents fired more than 10% of the time: {ten_percent:%}"
511+
)
507512
# Try to estimate simple feature frequency
508513
split_indices = torch.cumsum(counts, dim=0)
509514
activation_splits = torch.tensor_split(sorted_activations, split_indices[:-1])
@@ -525,8 +530,10 @@ def generate_statistics_cache(
525530
single_token_fraction = maybe_single_token_features / num_alive
526531
strong_token_fraction = num_single_token_features / num_alive
527532
if verbose:
528-
print(f"Fraction of weak single token latents: {single_token_fraction:%}")
529-
print(f"Fraction of strong single token latents: {strong_token_fraction:%}")
533+
logger.info(f"Fraction of weak single token latents: {single_token_fraction:%}")
534+
logger.info(
535+
f"Fraction of strong single token latents: {strong_token_fraction:%}"
536+
)
530537

531538
return CacheStatistics(
532539
frac_alive=float(fraction_alive),

delphi/latents/constructors.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from torch import Tensor
1212
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
1313

14+
from delphi import logger
15+
1416
from ..config import ConstructorConfig
1517
from .latents import (
1618
ActivatingExample,
@@ -25,7 +27,7 @@
2527
def get_model(name: str, device: str = "cuda") -> SentenceTransformer:
2628
global model_cache
2729
if (name, device) not in model_cache:
28-
print(f"Loading model {name} on device {device}")
30+
logger.info(f"Loading model {name} on device {device}")
2931
model_cache[(name, device)] = SentenceTransformer(name, device=device)
3032
return model_cache[(name, device)]
3133

@@ -284,7 +286,9 @@ def constructor(
284286
for toks, acts in zip(token_windows, act_windows)
285287
]
286288
if len(record.examples) < min_examples:
287-
print(f"Not enough examples to explain the latent: {len(record.examples)}")
289+
logger.warning(
290+
f"Not enough examples to explain the latent: {len(record.examples)}"
291+
)
288292
# Not enough examples to explain the latent
289293
return None
290294

@@ -404,7 +408,7 @@ def faiss_non_activation_windows(
404408

405409
# Check if we have enough non-activating examples
406410
if available_indices.numel() < n_not_active:
407-
print("Not enough non-activating examples available")
411+
logger.warning("Not enough non-activating examples available")
408412
return []
409413

410414
# Reshape tokens to get context windows
@@ -426,7 +430,7 @@ def faiss_non_activation_windows(
426430
]
427431

428432
if not activating_texts:
429-
print("No activating examples available")
433+
logger.warning("No activating examples available")
430434
return []
431435

432436
# Create unique cache keys for both activating and non-activating texts
@@ -451,17 +455,17 @@ def faiss_non_activation_windows(
451455
if cache_enabled and non_activating_cache_file.exists():
452456
try:
453457
index = faiss.read_index(str(non_activating_cache_file), faiss.IO_FLAG_MMAP)
454-
print(f"Loaded non-activating index from {non_activating_cache_file}")
458+
logger.info(f"Loaded non-activating index from {non_activating_cache_file}")
455459
except Exception as e:
456-
print(f"Error loading cached embeddings: {repr(e)}")
460+
logger.warning(f"Error loading cached embeddings: {repr(e)}")
457461

458462
if index is None:
459-
print("Decoding non-activating tokens...")
463+
logger.info("Decoding non-activating tokens...")
460464
non_activating_texts = [
461465
"".join(tokenizer.batch_decode(tokens)) for tokens in non_activating_tokens
462466
]
463467

464-
print("Computing non-activating embeddings...")
468+
logger.info("Computing non-activating embeddings...")
465469
non_activating_embeddings = get_model(embedding_model).encode(
466470
non_activating_texts, show_progress_bar=False
467471
)
@@ -472,26 +476,30 @@ def faiss_non_activation_windows(
472476
if cache_enabled:
473477
os.makedirs(cache_path, exist_ok=True)
474478
faiss.write_index(index, str(non_activating_cache_file))
475-
print(f"Cached non-activating embeddings to {non_activating_cache_file}")
479+
logger.info(
480+
f"Cached non-activating embeddings to {non_activating_cache_file}"
481+
)
476482

477483
activating_embeddings = None
478484
if cache_enabled and activating_cache_file.exists():
479485
try:
480486
activating_embeddings = np.load(activating_cache_file)
481-
print(f"Loaded cached activating embeddings from {activating_cache_file}")
487+
logger.info(
488+
f"Loaded cached activating embeddings from {activating_cache_file}"
489+
)
482490
except Exception as e:
483-
print(f"Error loading cached embeddings: {repr(e)}")
491+
logger.warning(f"Error loading cached embeddings: {repr(e)}")
484492
# Compute embeddings for activating examples if not cached
485493
if activating_embeddings is None:
486-
print("Computing activating embeddings...")
494+
logger.info("Computing activating embeddings...")
487495
activating_embeddings = get_model(embedding_model).encode(
488496
activating_texts, show_progress_bar=False
489497
)
490498
# Cache the embeddings
491499
if cache_enabled:
492500
os.makedirs(cache_path, exist_ok=True)
493501
np.save(activating_cache_file, activating_embeddings)
494-
print(f"Cached activating embeddings to {activating_cache_file}")
502+
logger.info(f"Cached activating embeddings to {activating_cache_file}")
495503

496504
# Search for the nearest neighbors to each activating example
497505
collected_indices = set()
@@ -618,7 +626,9 @@ def neighbour_non_activation_windows(
618626
)
619627
number_examples += examples_used
620628
if len(all_examples) == 0:
621-
print("No examples found, falling back to random non-activating examples")
629+
logger.warning(
630+
"No examples found, falling back to random non-activating examples"
631+
)
622632
non_active_indices = not_active_mask.nonzero(as_tuple=False).squeeze()
623633

624634
return random_non_activating_windows(
@@ -655,7 +665,7 @@ def random_non_activating_windows(
655665
# If this happens it means that the latent is active in every window,
656666
# so it is a bad latent
657667
if available_indices.numel() < n_not_active:
658-
print("No available randomly sampled non-activating sequences")
668+
logger.warning("No available randomly sampled non-activating sequences")
659669
return []
660670
else:
661671
random_indices = torch.randint(

delphi/latents/samplers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
PreTrainedTokenizerFast,
77
)
88

9+
from delphi import logger
10+
911
from ..config import SamplerConfig
10-
from ..logger import logger
1112
from .latents import ActivatingExample, LatentRecord
1213

1314

0 commit comments

Comments
 (0)