diff --git a/evaluate.py b/evaluate.py index bace801..66dd2ae 100644 --- a/evaluate.py +++ b/evaluate.py @@ -35,7 +35,7 @@ def log_mistakes_report(mistakes: pd.DataFrame, category: str, eval_timestamp: s mistakes.to_csv(f"{eval_directory}/mistakes_{eval_timestamp}_{category}.csv", index=False) -def evaluate_filter(category: str, filter_function: function, dataset: pd.DataFrame, eval_timestamp: str) -> dict: +def evaluate_filter(category: str, filter_function, dataset: pd.DataFrame, eval_timestamp: str) -> dict: """ Evaluate the classification performance of the provided filter @@ -48,7 +48,14 @@ def evaluate_filter(category: str, filter_function: function, dataset: pd.DataFr Returns: dict: The classification report of the filter """ - filter_judgments = dataset["shortened_text"].progress_apply(filter_function) + filter_judgments = [] + for i in tqdm(range(len(dataset))): + try: + filter_judgments.append(filter_function(dataset["shortened_text"][i])) + except: + filter_judgments.append(-1) + + # filter_judgments = dataset["shortened_text"].progress_apply(filter_function) filter_labels = dataset["Category"].progress_apply(lambda c: c == category) report_dict = classification_report(filter_labels, filter_judgments, output_dict=True) evaluation_log = { @@ -75,7 +82,7 @@ def evaluate(filters: dict): Args: filters (dict): The filters to evaluate. The key is the name of the category and value is the filter function. """ - dataset = pd.read_csv("datasets/eval/Pythia_70m_Deduped_Low_Perplexity_Labeling_Formatted.csv") + dataset = pd.read_csv("datasets/eval/Pythia_70m_Deduped_Low_Perplexity_Labeling_Formatted") eval_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") eval_results = [] for category, filter_function in filters.items(): diff --git a/filters/highly_duplicated_filter.py b/filters/highly_duplicated_filter.py new file mode 100644 index 0000000..27d899e --- /dev/null +++ b/filters/highly_duplicated_filter.py @@ -0,0 +1,57 @@ +from collections import Counter +from typing import Callable, List + +import pandas as pd + +def _concat_token_indices(token_indices: List[int], delimiter: str = '_') -> str: + """ + Concatenates a list of tokens into a single string. + + Args: + token_indices (List[int]): List of token indices to concatenate. + delimiter (str, optional): Delimiter to use for concatenation. Defaults to '_'. + + Returns: + str: Concatenated string of tokens indices. + """ + return delimiter.join([str(t) for t in token_indices]) + +def generate_sequence_histogram(token_indices: pd.Series, delimiter: str = '_') -> Counter[str, int]: + """ + Generates a histogram from a Pandas Series of token indices. The histogram is based on the concatenated strings of token indices. + + Args: + token_index_sequences (pd.Series): Pandas Series of token indices. + delimiter (str, optional): Delimiter to use for concatenation. Defaults to '_'. + + Returns: + Counter[str, int]: Histogram of strings of token indices. + """ + return Counter(token_indices.apply(lambda x: _concat_token_indices(x, delimiter=delimiter))) + +def get_highly_duplicated_filter_func(histogram: Counter[str, int], frequency_threshold: int = 1, delimiter: str = '_') -> Callable[[List[int]], bool]: + """ + Generates a filter function that checks if a list of token indices is highly duplicated. + + Args: + histogram (Counter[str, int]): Histogram of strings of token indices. + frequency_threshold (int, optional): Frequency threshold to use for filtering. Defaults to 1. + delimiter (str, optional): Delimiter to use for concatenation. Defaults to '_'. + + Returns: + Callable[[List[int]], bool]: Filter function that checks if a list of token indices is highly duplicated. + """ + def _highly_duplicated_filter_func(token_indices: List[int]) -> bool: + """ + Checks if a list of token indices is highly duplicated. + + Args: + token_indices (List[int]): List of token indices to check. + + Returns: + bool: True if the list of token indices is highly duplicated, False otherwise. + """ + token_string = _concat_token_indices(token_indices, delimiter=delimiter) + return histogram[token_string] > frequency_threshold + + return _highly_duplicated_filter_func diff --git a/filters/pattern_incrementing.py b/filters/pattern_incrementing.py index 53711e9..4fba1c6 100644 --- a/filters/pattern_incrementing.py +++ b/filters/pattern_incrementing.py @@ -1,2 +1,67 @@ -def incrementing_sequences_filter(text): +import re + +def incrementing_sequences_filter(text: str) -> bool: + """ + This sequence will classify a given text is an incrementing sequence or not. + + Args: + text (str): The current sequence to be classified. + + Returns: + bool: Whether the sequence is an incrementing sequence or not. + """ + # Split by seperators between text + possible_seperators = list(set(re.findall(r'(?<=\d)(\D+)(?=\d)', text))) + [" "] + ["\n"] + for seperator in possible_seperators: + # seperator = "" + # reading = None + # prev_char = None + # for index, character in enumerate(text): + # next_char = text[index + 1] if index + 1 < len(text) else "" + # if prev_char is None: + # prev_char = character + # if not character.isdigit() and not next_char.isdigit(): + # reading = True + # seperator += character + # if character.isdigit() and reading is True: + # break + + # prev_char = character + split_text = text.split(" " if seperator == "" else seperator) + + # trim the end if the final character(s) is a seperator + trailing_seperator = "" + for sep_index in range(len(seperator)): + if text.split(seperator)[-1][sep_index - 1:] == seperator[:sep_index + 1]: + trailing_seperator += seperator[:sep_index + 1] + else: + break + split_text[-1] = split_text[-1][:-len(trailing_seperator)] + + # Check if the sequence is just a list of digits + if len(split_text) == 1: + failed = False + prev_char = None + is_decrementing = None + for char in split_text[0]: + if char.isdigit(): + if prev_char is None and is_decrementing is None: + prev_char = char + elif is_decrementing is None: + is_decrementing = int(char) < int(prev_char) + prev_char = char + elif is_decrementing and (int(char) < int(prev_char)): + prev_char = char + elif not is_decrementing and (int(char) > int(prev_char)): + prev_char = char + else: + failed = True + break + else: + failed = True + break + if failed: + return False + + return True \ No newline at end of file diff --git a/filters/test_highly_duplicated_filter.py b/filters/test_highly_duplicated_filter.py new file mode 100644 index 0000000..3358020 --- /dev/null +++ b/filters/test_highly_duplicated_filter.py @@ -0,0 +1,30 @@ +import pandas as pd + +from .highly_duplicated_filter import get_highly_duplicated_filter_func, generate_sequence_histogram + +def test_highly_duplicated_filter_on_seen_indices(): + data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]]) + histogram = generate_sequence_histogram(data) + threshold = 1 + filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold) + + sample = [4, 5, 6] + assert filter_func(sample) == True + +def test_highly_duplicated_filter_on_unseen_indices(): + data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]]) + histogram = generate_sequence_histogram(data) + threshold = 1 + filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold) + + sample = [7, 8, 9] + assert filter_func(sample) == False + +def test_highly_duplicated_filter_on_infrequent_indices(): + data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]]) + histogram = generate_sequence_histogram(data) + threshold = 2 + filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold) + + sample = [4, 5, 6] + assert filter_func(sample) == False diff --git a/filters/test_pattern_incrementing.py b/filters/test_pattern_incrementing.py new file mode 100644 index 0000000..add828f --- /dev/null +++ b/filters/test_pattern_incrementing.py @@ -0,0 +1,44 @@ +from .pattern_incrementing import incrementing_sequences_filter + + +def test_pattern_incrementing_no_space(): + text = "123456789" + assert incrementing_sequences_filter(text) == True + + +def test_pattern_incrementing_no_space_with_char(): + text = "1A23456789" + assert incrementing_sequences_filter(text) == False + + +def test_pattern_incrementing(): + text = "12.8. 12.9. 13.0. 13.1. 13.2. 13.3." + assert incrementing_sequences_filter(text) == True + + +def test_pattern_new_lines_incrementing(): + text = "128.\n129.\n130.\n131.\n132.\n133." + assert incrementing_sequences_filter(text) == True + + +def test_pattern_list_incrementing(): + text = "- 128.\n- 129.\n- 130.\n- 131.\n- 132.\n- 133." + assert incrementing_sequences_filter(text) == True + + +def test_incrementing_nonnumerical_pattern(): + text = """ +![](edinbmedj75052-0047-b){#f5.123} + +![](edinbmedj75052-0049-a){#f6.125} + +![](edinbmedj75052-0049-b){#f7.125} + +![](edin +""" + assert incrementing_sequences_filter(text) == True + + +def test_incrementing_seminnumerical_pattern(): + text = "A.1 , A.2 , A.3 , A.4, B.1 , B.2, B.3, C.1" + assert incrementing_sequences_filter(text) == True diff --git a/requirements.txt b/requirements.txt index 4c7ccd4..805330d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ pandas numpy +scikit-learn torch torchvision torchaudio @@ -9,4 +10,4 @@ datasets tqdm black pylint -scikit-learn \ No newline at end of file +pytest \ No newline at end of file diff --git a/working_dirs/kyle/taxonemy_analysis/eval_set_v2.ipynb b/working_dirs/kyle/taxonemy_analysis/eval_set_v2.ipynb new file mode 100644 index 0000000..e193ba4 --- /dev/null +++ b/working_dirs/kyle/taxonemy_analysis/eval_set_v2.ipynb @@ -0,0 +1,1698 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from datasets import load_dataset\n", + "from transformers import AutoTokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
indexperplexitymemorizedis_codeshortened_textCategoryNote
0928833.687500TrueTrue}{-69pt}\\n \\begin{document}$u_{n}\\rightarrow u...codelatex
16858753.837891TrueFalsealesSite: All American Trannies\\n\\nFor Search ...nlNaN
29731522.884766TrueFalse18>::type T18;\\n typedef map<T0, T1, T2, T3, T...pattern-incrementingNaN
310169811.056641TrueTrue]{minimal}\\n \\usepackage{amsmath}\\n \\usepackag...codelatex
410893713.882812TrueTrue: 1,\\n\",\\n \"'col-md-push-6' : 1,\\n\",\\n \"'col-...pattern-incrementingNaN
........................
249525813513.392578FalseTrue2*y**2 + 6*y. Let z(g) = -3*g**2 - 7*g - 7. Le...code+nlmath
249625835343.597656FalseFalse039 ### ###',\\n '049 ### ###',\\n '050 ### ###'...pattern-incrementingNaN
249725846953.710938FalseTrue.1, -1?\\n-1\\nWhat is the second biggest value ...code+nlmath
249825861703.080078FalseTruepublic DbUpdateException()\\n {\\n }\\n\\n /// <su...codeNaN
249925930682.578125FalseTrueCLANG_WARN_BOOL_CONVERSION = YES;\\n CLANG_WARN...codeNaN
\n", + "

2500 rows × 7 columns

\n", + "
" + ], + "text/plain": [ + " index perplexity memorized is_code \n", + "0 92883 3.687500 True True \\\n", + "1 685875 3.837891 True False \n", + "2 973152 2.884766 True False \n", + "3 1016981 1.056641 True True \n", + "4 1089371 3.882812 True True \n", + "... ... ... ... ... \n", + "2495 2581351 3.392578 False True \n", + "2496 2583534 3.597656 False False \n", + "2497 2584695 3.710938 False True \n", + "2498 2586170 3.080078 False True \n", + "2499 2593068 2.578125 False True \n", + "\n", + " shortened_text Category \n", + "0 }{-69pt}\\n \\begin{document}$u_{n}\\rightarrow u... code \\\n", + "1 alesSite: All American Trannies\\n\\nFor Search ... nl \n", + "2 18>::type T18;\\n typedef map\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
indextokens__index_level_0__
0441[5584, 4196, 1228, 187, 1036, 4, 209, 21723, 2...441
1447[50262, 61, 2099, 92, 8861, 94, 187, 50262, 61...447
2792[475, 50272, 953, 24781, 778, 320, 908, 281, 1...792
31539[424, 380, 16101, 313, 433, 17889, 3104, 10, 2...1539
41705[3498, 2262, 2369, 40, 736, 13, 3956, 27, 21, ...1705
............
411443146431199[281, 320, 669, 8604, 60, 805, 431, 1019, 8402...2287199
411444146431278[588, 1705, 285, 8415, 634, 1895, 15, 30952, 3...2287278
411445146431294[15468, 13, 50275, 13743, 13, 50275, 15220, 13...2287294
411446146431588[27, 330, 14788, 10334, 14, 3429, 27, 577, 28,...2287588
411447146431592[1406, 485, 15, 23780, 300, 2473, 285, 12698, ...2287592
\n", + "

411448 rows × 3 columns

\n", + "" + ], + "text/plain": [ + " index tokens \n", + "0 441 [5584, 4196, 1228, 187, 1036, 4, 209, 21723, 2... \\\n", + "1 447 [50262, 61, 2099, 92, 8861, 94, 187, 50262, 61... \n", + "2 792 [475, 50272, 953, 24781, 778, 320, 908, 281, 1... \n", + "3 1539 [424, 380, 16101, 313, 433, 17889, 3104, 10, 2... \n", + "4 1705 [3498, 2262, 2369, 40, 736, 13, 3956, 27, 21, ... \n", + "... ... ... \n", + "411443 146431199 [281, 320, 669, 8604, 60, 805, 431, 1019, 8402... \n", + "411444 146431278 [588, 1705, 285, 8415, 634, 1895, 15, 30952, 3... \n", + "411445 146431294 [15468, 13, 50275, 13743, 13, 50275, 15220, 13... \n", + "411446 146431588 [27, 330, 14788, 10334, 14, 3429, 27, 577, 28,... \n", + "411447 146431592 [1406, 485, 15, 23780, 300, 2473, 285, 12698, ... \n", + "\n", + " __index_level_0__ \n", + "0 441 \n", + "1 447 \n", + "2 792 \n", + "3 1539 \n", + "4 1705 \n", + "... ... \n", + "411443 2287199 \n", + "411444 2287278 \n", + "411445 2287294 \n", + "411446 2287588 \n", + "411447 2287592 \n", + "\n", + "[411448 rows x 3 columns]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pythia_70m_memories = load_dataset(\"EleutherAI/pythia-memorized-evals\", split=\"deduped.70m\").to_pandas()\n", + "pythia_70m_memories" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
indexperplexitymemorizedis_codeshortened_textCategoryNotetokens__index_level_0__
0928833.687500TrueTrue}{-69pt}\\n \\begin{document}$u_{n}\\rightarrow u...codelatex[8699, 2090, 431, 94, 187, 50262, 61, 2043, 92...92883
16858753.837891TrueFalsealesSite: All American Trannies\\n\\nFor Search ...nlNaN[2339, 27327, 27, 1876, 2448, 1535, 1136, 447,...685875
29731522.884766TrueFalse18>::type T18;\\n typedef map<T0, T1, T2, T3, T...pattern-incrementingNaN[1093, 14157, 881, 308, 1093, 28, 187, 50266, ...973152
310169811.056641TrueTrue]{minimal}\\n \\usepackage{amsmath}\\n \\usepackag...codelatex[1019, 8402, 94, 187, 50262, 61, 2099, 92, 879...1016981
410893713.882812TrueTrue: 1,\\n\",\\n \"'col-md-push-6' : 1,\\n\",\\n \"'col-...pattern-incrementingNaN[8, 1163, 337, 1337, 79, 995, 187, 50274, 2789...1089371
..............................
248625558323.667969FalseTruedip)</pre>\\n</li>\\n</ul>\\n<a name=\"cornerRadiu...codeNaN[31665, 17266, 3456, 31, 187, 870, 965, 31, 18...267832
248825561542.320312FalseFalseLEASE COME TO MEXICO CITY PLEASE COME TO MEXIC...pattern-repeatingNaN[26084, 8610, 38, 5935, 353, 4237, 24218, 4589...268154
249125606554.386719FalseTrueNL_WABMON_4 = 131141\\n X_NL_WABMON_5 = 131142\\...pattern-incrementingNaN[19214, 64, 56, 2925, 22362, 64, 21, 50276, 30...272655
249325708524.101562FalseFalseWITH OCESAPLEASE COME MEXICO CITY WITH OCESAPL...pattern-repeatingNaN[9277, 27202, 1410, 2088, 26084, 8610, 38, 353...282852
249625835343.597656FalseFalse039 ### ###',\\n '049 ### ###',\\n '050 ### ###'...pattern-incrementingNaN[18832, 209, 4118, 209, 4118, 1383, 187, 50270...295534
\n", + "

1430 rows × 9 columns

\n", + "
" + ], + "text/plain": [ + " index perplexity memorized is_code \n", + "0 92883 3.687500 True True \\\n", + "1 685875 3.837891 True False \n", + "2 973152 2.884766 True False \n", + "3 1016981 1.056641 True True \n", + "4 1089371 3.882812 True True \n", + "... ... ... ... ... \n", + "2486 2555832 3.667969 False True \n", + "2488 2556154 2.320312 False False \n", + "2491 2560655 4.386719 False True \n", + "2493 2570852 4.101562 False False \n", + "2496 2583534 3.597656 False False \n", + "\n", + " shortened_text Category \n", + "0 }{-69pt}\\n \\begin{document}$u_{n}\\rightarrow u... code \\\n", + "1 alesSite: All American Trannies\\n\\nFor Search ... nl \n", + "2 18>::type T18;\\n typedef map\\n\\n\\n\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
indextokensis_memorized
018[15, 46525, 3439, 2526, 187, 14, 17, 15, 1036,...False
143[273, 22523, 18595, 275, 643, 3054, 2085, 3081...False
286[749, 10580, 273, 575, 5, 44, 64, 79, 5, 534, ...False
3110[12556, 187, 71, 437, 285, 45965, 13, 285, 253...False
4112[3847, 277, 2631, 449, 346, 1552, 310, 417, 82...False
............
4999995146431872[3117, 393, 6040, 416, 393, 5786, 393, 50, 5, ...False
4999996146431904[187, 6067, 1783, 2722, 326, 14108, 1638, 3400...False
4999997146431927[704, 39660, 1051, 187, 29, 56, 2711, 8537, 37...False
4999998146431960[14, 34552, 15390, 1253, 15280, 285, 1108, 447...False
4999999146431973[38630, 14716, 247, 15846, 8651, 5763, 15, 831...False
\n", + "

5000000 rows × 3 columns

\n", + "" + ], + "text/plain": [ + " index tokens \n", + "0 18 [15, 46525, 3439, 2526, 187, 14, 17, 15, 1036,... \\\n", + "1 43 [273, 22523, 18595, 275, 643, 3054, 2085, 3081... \n", + "2 86 [749, 10580, 273, 575, 5, 44, 64, 79, 5, 534, ... \n", + "3 110 [12556, 187, 71, 437, 285, 45965, 13, 285, 253... \n", + "4 112 [3847, 277, 2631, 449, 346, 1552, 310, 417, 82... \n", + "... ... ... \n", + "4999995 146431872 [3117, 393, 6040, 416, 393, 5786, 393, 50, 5, ... \n", + "4999996 146431904 [187, 6067, 1783, 2722, 326, 14108, 1638, 3400... \n", + "4999997 146431927 [704, 39660, 1051, 187, 29, 56, 2711, 8537, 37... \n", + "4999998 146431960 [14, 34552, 15390, 1253, 15280, 285, 1108, 447... \n", + "4999999 146431973 [38630, 14716, 247, 15846, 8651, 5763, 15, 831... \n", + "\n", + " is_memorized \n", + "0 False \n", + "1 False \n", + "2 False \n", + "3 False \n", + "4 False \n", + "... ... \n", + "4999995 False \n", + "4999996 False \n", + "4999997 False \n", + "4999998 False \n", + "4999999 False \n", + "\n", + "[5000000 rows x 3 columns]" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "non_memories = load_dataset(\"EleutherAI/pile-deduped-pythia-random-sampled\")[\"train\"].to_pandas()\n", + "non_memories" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
indexperplexitymemorizedis_codeshortened_textCategoryNotetokensis_memorized
1219992302.210938TrueFalseArmenians\",0,\"\",0,\"\",0,0,0,0,0,0,0,0,0,0,0,0,0...codeNaN[37801, 2458, 995, 17, 937, 995, 17, 937, 995,...False
218419992302.210938FalseFalseArmenians\",0,\"\",0,\"\",0,0,0,0,0,0,0,0,0,0,0,0,0...pattern-repeatingNaN[37801, 2458, 995, 17, 937, 995, 17, 937, 995,...False
1728149764.007812TrueTrue/brand-5\\nhttps://m.52010000.cn/brand-6\\nhttps...codeNaN[16, 22374, 14, 22, 187, 3614, 1358, 78, 15, 2...False
2636162182.611328TrueTrueCA5 },\\n { 0x10CE6, 0x10CA6 },\\n { 0x10CE7, 0x...pattern-incrementingNaN[4280, 22, 3572, 187, 50274, 92, 470, 89, 740,...False
8996572332.617188TrueTrue#ERROR!codeNaN[568, 2437, 275, 389, 15, 29762, 15, 26318, 15...False
..............................
249525813513.392578FalseTrue2*y**2 + 6*y. Let z(g) = -3*g**2 - 7*g - 7. Le...code+nlmath[374, 11, 90, 424, 19, 559, 721, 11, 90, 15, 1...False
249625835343.597656FalseFalse039 ### ###',\\n '049 ### ###',\\n '050 ### ###'...pattern-incrementingNaN[18832, 209, 4118, 209, 4118, 1383, 187, 50270...False
249725846953.710938FalseTrue.1, -1?\\n-1\\nWhat is the second biggest value ...code+nlmath[15, 18, 13, 428, 18, 32, 187, 14, 18, 187, 12...False
249825861703.080078FalseTruepublic DbUpdateException()\\n {\\n }\\n\\n /// <su...codeNaN[187, 50270, 4387, 46688, 11241, 5330, 1082, 1...False
249925930682.578125FalseTrueCLANG_WARN_BOOL_CONVERSION = YES;\\n CLANG_WARN...codeNaN[3207, 14375, 64, 24798, 64, 30529, 64, 5707, ...False
\n", + "

1298 rows × 9 columns

\n", + "
" + ], + "text/plain": [ + " index perplexity memorized is_code \n", + "12 1999230 2.210938 True False \\\n", + "2184 1999230 2.210938 False False \n", + "17 2814976 4.007812 True True \n", + "26 3616218 2.611328 True True \n", + "89 9657233 2.617188 True True \n", + "... ... ... ... ... \n", + "2495 2581351 3.392578 False True \n", + "2496 2583534 3.597656 False False \n", + "2497 2584695 3.710938 False True \n", + "2498 2586170 3.080078 False True \n", + "2499 2593068 2.578125 False True \n", + "\n", + " shortened_text Category \n", + "12 Armenians\",0,\"\",0,\"\",0,0,0,0,0,0,0,0,0,0,0,0,0... code \\\n", + "2184 Armenians\",0,\"\",0,\"\",0,0,0,0,0,0,0,0,0,0,0,0,0... pattern-repeating \n", + "17 /brand-5\\nhttps://m.52010000.cn/brand-6\\nhttps... code \n", + "26 CA5 },\\n { 0x10CE6, 0x10CA6 },\\n { 0x10CE7, 0x... pattern-incrementing \n", + "89 #ERROR! code \n", + "... ... ... \n", + "2495 2*y**2 + 6*y. Let z(g) = -3*g**2 - 7*g - 7. Le... code+nl \n", + "2496 039 ### ###',\\n '049 ### ###',\\n '050 ### ###'... pattern-incrementing \n", + "2497 .1, -1?\\n-1\\nWhat is the second biggest value ... code+nl \n", + "2498 public DbUpdateException()\\n {\\n }\\n\\n /// \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
indexmemorizedperplexityis_codepromptsequenceCodeIncrementalRepetitiveHighly DuplicatedTemplatingNatural LanguageRandomOtherNotes
17201042625False4.777344True4\\nLet f be ((-1)/2)/(2/4). Let a = 6 - 5. Let...4\\nLet f be ((-1)/2)/(2/4). Let a = 6 - 5. Let...
1412378406False4.007812True2/2 + 3*r - 3. Let g(x) be the first derivativ...2/2 + 3*r - 3. Let g(x) be the first derivativ...
17941188414False1.437500Trueusepackage{amsmath}\\n\\usepackage{wasysym} \\n\\u...usepackage{amsmath}\\n\\usepackage{wasysym} \\n\\u...
19851572476False3.982422Truev - 2*v + 6 = 0. Let k be (3/v)/(2/(-8)). Let ...v - 2*v + 6 = 0. Let k be (3/v)/(2/(-8)). Let ...
51060937406True2.820312True\\n\\n![](amjdentsci80652-0039){#sp2.143}\\n\\n![]...\\n\\n![](amjdentsci80652-0039){#sp2.143}\\n\\n![]...
................................................
1055121637700True1.520508True=\"1.0\" encoding=\"UTF-8\"?>\\n<!DOCTYPE plist PUB...=\"1.0\" encoding=\"UTF-8\"?>\\n<!DOCTYPE plist PUB...
717827439True3.263672Trueper Team is already on the scene....<?xml vers...per Team is already on the scene....<?xml vers...
20111644096False3.005859False11.9 ± 2.0 11.5 ± 2.0 11.2 ± 2.2 \\<11.9 ± 2.0 11.5 ± 2.0 11.2 ± 2.2 \\...
1028118099130True2.289062Falseref 8, ref 9, ref 10, ref 11, ref 12, ref 13,...ref 8, ref 9, ref 10, ref 11, ref 12, ref 13,...
1026118045648True2.908203TrueISA as two detectors, so that the signal in ea...ISA as two detectors, so that the signal in ea...
\n", + "

2499 rows × 15 columns

\n", + "" + ], + "text/plain": [ + " index memorized perplexity is_code \n", + "1720 1042625 False 4.777344 True \\\n", + "1412 378406 False 4.007812 True \n", + "1794 1188414 False 1.437500 True \n", + "1985 1572476 False 3.982422 True \n", + "510 60937406 True 2.820312 True \n", + "... ... ... ... ... \n", + "1055 121637700 True 1.520508 True \n", + "71 7827439 True 3.263672 True \n", + "2011 1644096 False 3.005859 False \n", + "1028 118099130 True 2.289062 False \n", + "1026 118045648 True 2.908203 True \n", + "\n", + " prompt \n", + "1720 4\\nLet f be ((-1)/2)/(2/4). Let a = 6 - 5. Let... \\\n", + "1412 2/2 + 3*r - 3. Let g(x) be the first derivativ... \n", + "1794 usepackage{amsmath}\\n\\usepackage{wasysym} \\n\\u... \n", + "1985 v - 2*v + 6 = 0. Let k be (3/v)/(2/(-8)). Let ... \n", + "510 \\n\\n![](amjdentsci80652-0039){#sp2.143}\\n\\n![]... \n", + "... ... \n", + "1055 =\"1.0\" encoding=\"UTF-8\"?>\\n\\n\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
indexmemorizedperplexityis_codepromptsequenceCodeIncrementalRepetitiveHighly DuplicatedTemplatingNatural LanguageRandomOtherNotes
1497576734False3.447266True(-1.0, -1.0, 0);\\nglTexCoord2f(0, 0);\\nglVerte...(-1.0, -1.0, 0);\\nglTexCoord2f(0, 0);\\nglVerte...
38945993218True1.752930Truenot use this file except in compliance with t...not use this file except in compliance with t...
1095126795714True4.429688True(c) 2019 Wei Wang <onevcat@gmail.com>\\n//\\n//...(c) 2019 Wei Wang <onevcat@gmail.com>\\n//\\n//...
49658907866True1.357422Falsehttp://www.apache.org/licenses/LICENSE-2.0\\n\\n...http://www.apache.org/licenses/LICENSE-2.0\\n\\n...
1647911620False3.562500True=\"table-fn\"} ...=\"table-fn\"} ...
................................................
20831786406False1.972656True\\n \\usepackage{amssymb} \\n ...\\n \\usepackage{amssymb} \\n ...
1202139649808True1.101562True$\\documentclass[12pt]{minimal}\\n ...$\\documentclass[12pt]{minimal}\\n ...
42850719780True1.749023Truein compliance with the License.\\n// You may o...in compliance with the License.\\n// You may o...
20201659069False3.974609False?\\n3\\nWhat is the ninth root of 113001 to the ...?\\n3\\nWhat is the ninth root of 113001 to the ...
50660582215True1.504883Truegood judgment.//\\n// Generated by class-d...good judgment.//\\n// Generated by class-d...
\n", + "

100 rows × 15 columns

\n", + "" + ], + "text/plain": [ + " index memorized perplexity is_code \n", + "1497 576734 False 3.447266 True \\\n", + "389 45993218 True 1.752930 True \n", + "1095 126795714 True 4.429688 True \n", + "496 58907866 True 1.357422 False \n", + "1647 911620 False 3.562500 True \n", + "... ... ... ... ... \n", + "2083 1786406 False 1.972656 True \n", + "1202 139649808 True 1.101562 True \n", + "428 50719780 True 1.749023 True \n", + "2020 1659069 False 3.974609 False \n", + "506 60582215 True 1.504883 True \n", + "\n", + " prompt \n", + "1497 (-1.0, -1.0, 0);\\nglTexCoord2f(0, 0);\\nglVerte... \\\n", + "389 not use this file except in compliance with t... \n", + "1095 (c) 2019 Wei Wang \\n//\\n//... \n", + "496 http://www.apache.org/licenses/LICENSE-2.0\\n\\n... \n", + "1647 =\"table-fn\"} ... \n", + "... ... \n", + "2083 \\n \\usepackage{amssymb} \\n ... \n", + "1202 $\\documentclass[12pt]{minimal}\\n ... \n", + "428 in compliance with the License.\\n// You may o... \n", + "2020 ?\\n3\\nWhat is the ninth root of 113001 to the ... \n", + "506 good judgment.//\\n// Generated by class-d... \n", + "\n", + " sequence Code Incremental \n", + "1497 (-1.0, -1.0, 0);\\nglTexCoord2f(0, 0);\\nglVerte... \\\n", + "389 not use this file except in compliance with t... \n", + "1095 (c) 2019 Wei Wang \\n//\\n//... \n", + "496 http://www.apache.org/licenses/LICENSE-2.0\\n\\n... \n", + "1647 =\"table-fn\"} ... \n", + "... ... ... ... \n", + "2083 \\n \\usepackage{amssymb} \\n ... \n", + "1202 $\\documentclass[12pt]{minimal}\\n ... \n", + "428 in compliance with the License.\\n// You may o... \n", + "2020 ?\\n3\\nWhat is the ninth root of 113001 to the ... \n", + "506 good judgment.//\\n// Generated by class-d... \n", + "\n", + " Repetitive Highly Duplicated Templating Natural Language Random Other \n", + "1497 \\\n", + "389 \n", + "1095 \n", + "496 \n", + "1647 \n", + "... ... ... ... ... ... ... \n", + "2083 \n", + "1202 \n", + "428 \n", + "2020 \n", + "506 \n", + "\n", + " Notes \n", + "1497 \n", + "389 \n", + "1095 \n", + "496 \n", + "1647 \n", + "... ... \n", + "2083 \n", + "1202 \n", + "428 \n", + "2020 \n", + "506 \n", + "\n", + "[100 rows x 15 columns]" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "combined_strat_downsample = pd.concat(\n", + " [combined_joined_eval_set[combined_joined_eval_set[\"memorized\"]].sample(50),\n", + " combined_joined_eval_set[~combined_joined_eval_set[\"memorized\"]].sample(50)]).sample(frac=1)\n", + "\n", + "combined_strat_downsample.to_csv(\"combined_strat_downsample.csv\", index=False)\n", + "combined_strat_downsample" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "memorization", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}