diff --git a/README.md b/README.md index 412de54..d8852a5 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ The overall pipeline is as follows:

+style="width:5.02in;height:0.71in" />

diff --git a/index_files/figure-commonmark/mermaid-figure-1.png b/index_files/figure-commonmark/mermaid-figure-1.png index b3708b2..016b0b4 100644 Binary files a/index_files/figure-commonmark/mermaid-figure-1.png and b/index_files/figure-commonmark/mermaid-figure-1.png differ diff --git a/nbs/01_filter.ipynb b/nbs/01_filter.ipynb index e7b0e3c..1ced968 100644 --- a/nbs/01_filter.ipynb +++ b/nbs/01_filter.ipynb @@ -33,6 +33,7 @@ "import os\n", "import random\n", "import re\n", + "from transformers import AutoTokenizer\n", "\n", "import dill as pickle\n", "import networkit as nk\n", @@ -146,6 +147,43 @@ "assert check_char_repetition(test_str, char_repetition_len=3, char_repetition_threshold=0.2) == False" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def check_document_length_tokenized(\n", + " document:str, #document to be analyzed\n", + " tokenizer:AutoTokenizer, #tokenizer to tokenize the documents\n", + " document_lower_len_threshold:int = 10, #document length threshold to filter\n", + ") -> bool: #returns True if doument length is above threshold else False\n", + " \"\"\"\n", + " Returns True if it's above the threshold, else returns False\n", + " \"\"\"\n", + " tokenized = tokenizer(document).input_ids\n", + " if len(tokenized) > document_lower_len_threshold:\n", + " return True\n", + " else:\n", + " return False\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_str_false = \"single tok\"\n", + "test_str_true = \"Hello this is a long text document with the hopes of being greater than threshold\"\n", + "\n", + "trial_tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neox-20b\")\n", + "assert check_document_length_tokenized(test_str_false,tokenizer=trial_tokenizer) == False\n", + "assert check_document_length_tokenized(test_str_true,tokenizer=trial_tokenizer) == True\n" + ] + }, { "cell_type": "code", "execution_count": null, @@ -733,7 +771,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.10.8 ('squeakily')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } diff --git a/nbs/sidebar.yml b/nbs/sidebar.yml index 804fbed..8676a45 100644 --- a/nbs/sidebar.yml +++ b/nbs/sidebar.yml @@ -5,4 +5,5 @@ website: - 00_core.ipynb - 01_filter.ipynb - 02_clean.ipynb - - 03_helpers.ipynb \ No newline at end of file + - 03_helpers.ipynb + - 04_tutorials.ipynb diff --git a/settings.ini b/settings.ini index f5a1755..1eb3e37 100644 --- a/settings.ini +++ b/settings.ini @@ -25,7 +25,7 @@ keywords = nbdev jupyter notebook python language = English status = 3 user = CarperAI -requirements = datasketch==1.5.8 datasets==2.7.1 Faker==15.3.3 fastcore networkit rich +requirements = datasketch==1.5.8 datasets==2.7.1 Faker==15.3.3 fastcore networkit rich transformers dev_requirements = nbdev scrubadub black_formatting = False readme_nb = index.ipynb diff --git a/squeakily/_modidx.py b/squeakily/_modidx.py index 556a777..dfdaec8 100644 --- a/squeakily/_modidx.py +++ b/squeakily/_modidx.py @@ -28,6 +28,8 @@ 'squeakily.filter._jaccard_similarity': ('filter.html#_jaccard_similarity', 'squeakily/filter.py'), 'squeakily.filter._query_content': ('filter.html#_query_content', 'squeakily/filter.py'), 'squeakily.filter.check_char_repetition': ('filter.html#check_char_repetition', 'squeakily/filter.py'), + 'squeakily.filter.check_document_length_tokenized': ( 'filter.html#check_document_length_tokenized', + 'squeakily/filter.py'), 'squeakily.filter.check_exact_match': ('filter.html#check_exact_match', 'squeakily/filter.py'), 'squeakily.filter.check_flagged_words': ('filter.html#check_flagged_words', 'squeakily/filter.py'), 'squeakily.filter.minhash_dedup': ('filter.html#minhash_dedup', 'squeakily/filter.py')}, diff --git a/squeakily/filter.py b/squeakily/filter.py index e933bd8..0044402 100644 --- a/squeakily/filter.py +++ b/squeakily/filter.py @@ -1,7 +1,8 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_filter.ipynb. # %% auto 0 -__all__ = ['logger', 'MINHASH_SEED', 'NON_ALPHA', 'lsh', 'dup_ids', 'check_char_repetition', 'check_exact_match', 'minhash_dedup'] +__all__ = ['logger', 'MINHASH_SEED', 'NON_ALPHA', 'lsh', 'dup_ids', 'check_char_repetition', 'check_document_length_tokenized', + 'check_exact_match', 'minhash_dedup'] # %% ../nbs/01_filter.ipynb 2 import datasets @@ -12,6 +13,7 @@ import os import random import re +from transformers import AutoTokenizer import dill as pickle import networkit as nk @@ -80,9 +82,25 @@ def check_char_repetition( return char_rep_ratio <= char_repetition_threshold # %% ../nbs/01_filter.ipynb 8 +def check_document_length_tokenized( + document:str, #document to be analyzed + tokenizer:AutoTokenizer, #tokenizer to tokenize the documents + document_lower_len_threshold:int = 10, #document length threshold to filter +) -> bool: #returns True if doument length is above threshold else False + """ + Returns True if it's above the threshold, else returns False + """ + tokenized = tokenizer(document).input_ids + if len(tokenized) > document_lower_len_threshold: + return True + else: + return False + + +# %% ../nbs/01_filter.ipynb 10 def check_exact_match(): pass -# %% ../nbs/01_filter.ipynb 9 +# %% ../nbs/01_filter.ipynb 11 def _flag_word_ratio( doc: str, # document to be analyzed flagged_words: list, # list of flagged words @@ -101,7 +119,7 @@ def _flag_word_ratio( flagged_words_ratio = 1.0 return flagged_words_ratio -# %% ../nbs/01_filter.ipynb 10 +# %% ../nbs/01_filter.ipynb 12 def check_flagged_words( document: str, # document to be analyzed flagged_words: list = flagged_words["en"], # list of flagged words @@ -121,7 +139,7 @@ def check_flagged_words( cond = flagged_words_ratio <= flagged_words_threshold return cond -# %% ../nbs/01_filter.ipynb 15 +# %% ../nbs/01_filter.ipynb 17 multiprocessing.set_start_method("fork", force=True) MINHASH_SEED = 115 @@ -132,7 +150,7 @@ def check_flagged_words( lsh: MinHashLSH = None dup_ids: set[int] = None -# %% ../nbs/01_filter.ipynb 16 +# %% ../nbs/01_filter.ipynb 18 def _hash_func( idx: int, # The index of the record. content: str, # The content to be hashed. @@ -154,7 +172,7 @@ def _hash_func( m.update_batch([token.encode("utf-8") for token in {t for t in NON_ALPHA.split(content) if t}]) return {"__signature__": m.hashvalues, "__id__": idx} -# %% ../nbs/01_filter.ipynb 18 +# %% ../nbs/01_filter.ipynb 20 def _query_content( idx: int, # The index of the record. signature: np.ndarray, # The MinHash signature of the record to be queried. @@ -176,7 +194,7 @@ def _query_content( "__id__": idx, } -# %% ../nbs/01_filter.ipynb 20 +# %% ../nbs/01_filter.ipynb 22 def _jaccard_similarity( s1: str, # The first string to compare. s2: str # The second string to compare. @@ -188,7 +206,7 @@ def _jaccard_similarity( tokens2 = set([t for t in NON_ALPHA.split(s2) if t.strip()]) return len(tokens1 & tokens2) / max(1, len(tokens1 | tokens2)) -# %% ../nbs/01_filter.ipynb 22 +# %% ../nbs/01_filter.ipynb 24 def _calculate_average_false_positive_rate( clusters: list[list[int]], # The clusters of duplicate records. reference_records: Dataset, # The reference records. @@ -234,7 +252,7 @@ def _calculate_average_false_positive_rate( logger.info(f"- Mean: {np.mean(deltas):0.2f}") logger.info(f"- Std : {np.std(deltas):0.2f}") -# %% ../nbs/01_filter.ipynb 23 +# %% ../nbs/01_filter.ipynb 25 def _find_duplicate_communities( records: Dataset, # The dataset that contains both `__id__` and `__neighbors__`. community_detection: bool, # Whether to use community detection to find the duplicate communities, or to use the connected components. @@ -291,7 +309,7 @@ def _find_duplicate_communities( return to_remove -# %% ../nbs/01_filter.ipynb 24 +# %% ../nbs/01_filter.ipynb 26 def minhash_dedup( ds, # The dataset to deduplicate. column, # The column to use for deduplication.