diff --git a/README.md b/README.md
index 412de54..d8852a5 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ The overall pipeline is as follows:
+style="width:5.02in;height:0.71in" />
diff --git a/index_files/figure-commonmark/mermaid-figure-1.png b/index_files/figure-commonmark/mermaid-figure-1.png
index b3708b2..016b0b4 100644
Binary files a/index_files/figure-commonmark/mermaid-figure-1.png and b/index_files/figure-commonmark/mermaid-figure-1.png differ
diff --git a/nbs/01_filter.ipynb b/nbs/01_filter.ipynb
index e7b0e3c..1ced968 100644
--- a/nbs/01_filter.ipynb
+++ b/nbs/01_filter.ipynb
@@ -33,6 +33,7 @@
"import os\n",
"import random\n",
"import re\n",
+ "from transformers import AutoTokenizer\n",
"\n",
"import dill as pickle\n",
"import networkit as nk\n",
@@ -146,6 +147,47 @@
"assert check_char_repetition(test_str, char_repetition_len=3, char_repetition_threshold=0.2) == False"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "def check_document_length_tokenized(\n",
+ " document:str, #document to be analyzed\n",
+ " tokenizer:AutoTokenizer, #tokenizer to tokenize the documents\n",
+ " document_lower_len_threshold:int = 10, #document length threshold to filter\n",
+ ") -> bool: #returns True if doument length is above threshold else False\n",
+ " \"\"\"\n",
+ " Returns True if it's above the threshold, else returns False\n",
+ " \"\"\"\n",
+ " tokenized = tokenizer(document).input_ids\n",
+ " if len(tokenized) > document_lower_len_threshold:\n",
+ " return True\n",
+ " else:\n",
+ " return False\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_str_false = \"single tok\"\n",
+ "test_str_true = \"Hello this is a long text document with the hopes of being greater than threshold\"\n",
+ "\n",
+ "trial_tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neox-20b\")\n",
+ "assert check_document_length_tokenized(test_str_false,tokenizer=trial_tokenizer) == False\n",
+ "assert check_document_length_tokenized(test_str_true,tokenizer=trial_tokenizer) == True\n"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -733,7 +771,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3.10.8 ('squeakily')",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
diff --git a/nbs/sidebar.yml b/nbs/sidebar.yml
index 804fbed..8676a45 100644
--- a/nbs/sidebar.yml
+++ b/nbs/sidebar.yml
@@ -5,4 +5,5 @@ website:
- 00_core.ipynb
- 01_filter.ipynb
- 02_clean.ipynb
- - 03_helpers.ipynb
\ No newline at end of file
+ - 03_helpers.ipynb
+ - 04_tutorials.ipynb
diff --git a/settings.ini b/settings.ini
index f5a1755..1eb3e37 100644
--- a/settings.ini
+++ b/settings.ini
@@ -25,7 +25,7 @@ keywords = nbdev jupyter notebook python
language = English
status = 3
user = CarperAI
-requirements = datasketch==1.5.8 datasets==2.7.1 Faker==15.3.3 fastcore networkit rich
+requirements = datasketch==1.5.8 datasets==2.7.1 Faker==15.3.3 fastcore networkit rich transformers
dev_requirements = nbdev scrubadub
black_formatting = False
readme_nb = index.ipynb
diff --git a/squeakily/_modidx.py b/squeakily/_modidx.py
index 556a777..dfdaec8 100644
--- a/squeakily/_modidx.py
+++ b/squeakily/_modidx.py
@@ -28,6 +28,8 @@
'squeakily.filter._jaccard_similarity': ('filter.html#_jaccard_similarity', 'squeakily/filter.py'),
'squeakily.filter._query_content': ('filter.html#_query_content', 'squeakily/filter.py'),
'squeakily.filter.check_char_repetition': ('filter.html#check_char_repetition', 'squeakily/filter.py'),
+ 'squeakily.filter.check_document_length_tokenized': ( 'filter.html#check_document_length_tokenized',
+ 'squeakily/filter.py'),
'squeakily.filter.check_exact_match': ('filter.html#check_exact_match', 'squeakily/filter.py'),
'squeakily.filter.check_flagged_words': ('filter.html#check_flagged_words', 'squeakily/filter.py'),
'squeakily.filter.minhash_dedup': ('filter.html#minhash_dedup', 'squeakily/filter.py')},
diff --git a/squeakily/filter.py b/squeakily/filter.py
index e933bd8..0044402 100644
--- a/squeakily/filter.py
+++ b/squeakily/filter.py
@@ -1,7 +1,8 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_filter.ipynb.
# %% auto 0
-__all__ = ['logger', 'MINHASH_SEED', 'NON_ALPHA', 'lsh', 'dup_ids', 'check_char_repetition', 'check_exact_match', 'minhash_dedup']
+__all__ = ['logger', 'MINHASH_SEED', 'NON_ALPHA', 'lsh', 'dup_ids', 'check_char_repetition', 'check_document_length_tokenized',
+ 'check_exact_match', 'minhash_dedup']
# %% ../nbs/01_filter.ipynb 2
import datasets
@@ -12,6 +13,7 @@
import os
import random
import re
+from transformers import AutoTokenizer
import dill as pickle
import networkit as nk
@@ -80,9 +82,22 @@ def check_char_repetition(
return char_rep_ratio <= char_repetition_threshold
# %% ../nbs/01_filter.ipynb 8
+def check_document_length_tokenized(
+    document: str,  # document to be analyzed
+    tokenizer: AutoTokenizer,  # tokenizer used to tokenize the document
+    document_lower_len_threshold: int = 10,  # minimum token count for a document to pass
+) -> bool:  # returns True if the document's token length is above the threshold, else False
+    """
+    Returns True if the tokenized document is longer than `document_lower_len_threshold`, else False.
+    """
+    tokenized = tokenizer(document).input_ids
+    return len(tokenized) > document_lower_len_threshold
+
+
+# %% ../nbs/01_filter.ipynb 10
def check_exact_match(): pass
-# %% ../nbs/01_filter.ipynb 9
+# %% ../nbs/01_filter.ipynb 11
def _flag_word_ratio(
doc: str, # document to be analyzed
flagged_words: list, # list of flagged words
@@ -101,7 +119,7 @@ def _flag_word_ratio(
flagged_words_ratio = 1.0
return flagged_words_ratio
-# %% ../nbs/01_filter.ipynb 10
+# %% ../nbs/01_filter.ipynb 12
def check_flagged_words(
document: str, # document to be analyzed
flagged_words: list = flagged_words["en"], # list of flagged words
@@ -121,7 +139,7 @@ def check_flagged_words(
cond = flagged_words_ratio <= flagged_words_threshold
return cond
-# %% ../nbs/01_filter.ipynb 15
+# %% ../nbs/01_filter.ipynb 17
multiprocessing.set_start_method("fork", force=True)
MINHASH_SEED = 115
@@ -132,7 +150,7 @@ def check_flagged_words(
lsh: MinHashLSH = None
dup_ids: set[int] = None
-# %% ../nbs/01_filter.ipynb 16
+# %% ../nbs/01_filter.ipynb 18
def _hash_func(
idx: int, # The index of the record.
content: str, # The content to be hashed.
@@ -154,7 +172,7 @@ def _hash_func(
m.update_batch([token.encode("utf-8") for token in {t for t in NON_ALPHA.split(content) if t}])
return {"__signature__": m.hashvalues, "__id__": idx}
-# %% ../nbs/01_filter.ipynb 18
+# %% ../nbs/01_filter.ipynb 20
def _query_content(
idx: int, # The index of the record.
signature: np.ndarray, # The MinHash signature of the record to be queried.
@@ -176,7 +194,7 @@ def _query_content(
"__id__": idx,
}
-# %% ../nbs/01_filter.ipynb 20
+# %% ../nbs/01_filter.ipynb 22
def _jaccard_similarity(
s1: str, # The first string to compare.
s2: str # The second string to compare.
@@ -188,7 +206,7 @@ def _jaccard_similarity(
tokens2 = set([t for t in NON_ALPHA.split(s2) if t.strip()])
return len(tokens1 & tokens2) / max(1, len(tokens1 | tokens2))
-# %% ../nbs/01_filter.ipynb 22
+# %% ../nbs/01_filter.ipynb 24
def _calculate_average_false_positive_rate(
clusters: list[list[int]], # The clusters of duplicate records.
reference_records: Dataset, # The reference records.
@@ -234,7 +252,7 @@ def _calculate_average_false_positive_rate(
logger.info(f"- Mean: {np.mean(deltas):0.2f}")
logger.info(f"- Std : {np.std(deltas):0.2f}")
-# %% ../nbs/01_filter.ipynb 23
+# %% ../nbs/01_filter.ipynb 25
def _find_duplicate_communities(
records: Dataset, # The dataset that contains both `__id__` and `__neighbors__`.
community_detection: bool, # Whether to use community detection to find the duplicate communities, or to use the connected components.
@@ -291,7 +309,7 @@ def _find_duplicate_communities(
return to_remove
-# %% ../nbs/01_filter.ipynb 24
+# %% ../nbs/01_filter.ipynb 26
def minhash_dedup(
ds, # The dataset to deduplicate.
column, # The column to use for deduplication.