diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 54f9789..a713081 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -65,18 +65,6 @@ jobs: runs-on: ubuntu-latest needs: validate-compute-block services: - minio: - image: lazybit/minio - ports: - - 9000:9000 - env: - MINIO_ROOT_USER: minioadmin - MINIO_ROOT_PASSWORD: minioadmin - options: >- - --health-cmd "curl -f http://localhost:9000/minio/health/live || exit 1" - --health-interval 5s - --health-retries 5 - --health-timeout 5s postgres: image: postgres:15 ports: diff --git a/algorithms/lda.py b/algorithms/lda.py index b916540..3364ee3 100644 --- a/algorithms/lda.py +++ b/algorithms/lda.py @@ -13,6 +13,7 @@ def __init__( self, dtm: np.ndarray = None, vocab: dict = None, + doc_ids: list[str] = [], n_topics: int = 10, max_iter: int = 10, learning_method: str = "batch", @@ -21,6 +22,12 @@ def __init__( ): self.dtm: np.ndarray = dtm self.vocab: dict = vocab + self.doc_ids = doc_ids + + if len(self.doc_ids) != self.dtm.shape[0]: + raise ValueError( + "doc_ids length must match number of DTM rows" + ) self.n_topics = n_topics self.max_iter = max_iter @@ -58,6 +65,7 @@ def extract_doc_topics(self) -> pd.DataFrame: self.doc_topic_dist, columns=[f"topic_{i}" for i in range(self.n_topics)], ) + df.insert(0, "doc_id", self.doc_ids) logger.debug( f"Extracted doc-topic distribution DataFrame shape={df.shape}") @@ -68,17 +76,28 @@ def extract_topic_terms(self): Generate topic and top-terms DataFrame """ logger.info("Extracting top terms per topic...") - idx2term = {idx: term for term, idx in self.vocab.items()} + + # NOTE: + # The order of `terms` is guaranteed to match the DTM column order. + # This is because the vocabulary is built in NLPVectorizer using: + # sorted_terms = sorted(all_terms) + # vocab = {term: i for i, term in enumerate(sorted_terms)} + # The same vocab indices are then used to construct the DTM columns. + # Since Python dicts preserve insertion order (>=3.7), + # list(self.vocab.keys())[i] correctly maps to DTM column i, + # and thus to lda.components_[topic_idx][i]. + terms = list(self.vocab.keys()) topic_rows = [] for topic_idx, topic in enumerate(self.lda.components_): - sorted = np.argsort(topic)[::-1] - top_indices = sorted[:self.n_top_words] + sorted_idx = np.argsort(topic)[::-1] + top_indices = sorted_idx[: self.n_top_words] + for i in top_indices: topic_rows.append({ "topic_id": topic_idx, - "term": idx2term[i], - "weight": topic[i] + "term": terms[int(i)], + "weight": topic[i], }) df = pd.DataFrame(topic_rows) diff --git a/algorithms/models.py b/algorithms/models.py new file mode 100644 index 0000000..598c980 --- /dev/null +++ b/algorithms/models.py @@ -0,0 +1,8 @@ +from typing import List +from dataclasses import dataclass + + +@dataclass +class PreprocessedDocument: + doc_id: str + tokens: List[str] diff --git a/algorithms/vectorizer.py b/algorithms/vectorizer.py new file mode 100644 index 0000000..e02ee66 --- /dev/null +++ b/algorithms/vectorizer.py @@ -0,0 +1,103 @@ +from typing import List +import numpy as np +from collections import Counter + +from algorithms.models import PreprocessedDocument + + +class NLPVectorizer: + def __init__(self, preprocessed_output: List[PreprocessedDocument]): + self.documents = preprocessed_output + self.doc_ids = [doc.doc_id for doc in preprocessed_output] + + # Frequencies + self.token_frequency = Counter() + self.token_document_frequency = Counter() + self.ngram_frequency = Counter() + self.ngram_document_frequency = Counter() + + # bow + dtm + self.bag_of_words = [] + self.vocab = {} + self.reverse_vocab = [] + self.dtm = None + + def analyze_frequencies(self): + for doc in self.documents: + tokens = [t for t in doc.tokens if " " not in t] + ngrams = [t for t in doc.tokens if " " in t] + + # token frequencies + self.token_frequency.update(tokens) + self.token_document_frequency.update(set(tokens)) + + # ngram frequencies + self.ngram_frequency.update(ngrams) + self.ngram_document_frequency.update(set(ngrams)) + + def build_bow(self): + bow = [] + + for doc in self.documents: + entries = [] + unique = set() + + for term in doc.tokens: + if term in unique: + continue + unique.add(term) + + is_ngram = " " in term + + entry = { + "term": term, + "type": "ngram" if is_ngram else "word", + "span": len(term.split(" ")), + "freq": ( + self.ngram_frequency[term] + if is_ngram + else self.token_frequency[term] + ), + "docs": ( + self.ngram_document_frequency[term] + if is_ngram + else self.token_document_frequency[term] + ), + "filters": [] + } + + entries.append(entry) + + bow.append(entries) + + self.bag_of_words = bow + return bow + + def build_vocabulary(self): + all_terms = set() + + for doc in self.documents: + for term in doc.tokens: + all_terms.add(term) + + sorted_terms = sorted(all_terms) + self.vocab = {term: i for i, term in enumerate(sorted_terms)} + self.reverse_vocab = sorted_terms + + return self.vocab + + def build_dtm(self): + if not self.vocab: + self.build_vocabulary() + + num_docs = len(self.documents) + num_terms = len(self.vocab) + + dtm = np.zeros((num_docs, num_terms), dtype=int) + + for i, doc in enumerate(self.documents): + for term in doc.tokens: + dtm[i, self.vocab[term]] += 1 + + self.dtm = dtm + return dtm diff --git a/cbc.yaml b/cbc.yaml index 8b9ea60..a98f0ce 100644 --- a/cbc.yaml +++ b/cbc.yaml @@ -1,39 +1,24 @@ author: Paul Kalhorn -description: Compute Block that offers Topic Modeling Algorihtms +description: Compute Block that offers topic modeling algorithms docker_image: ghcr.io/rwth-time/topic-modeling/topic-modeling entrypoints: lda_topic_modeling: - description: Sklearn LDA Topic Modeling + description: Sklearn LDA Topic Modeling envs: LEARNING_METHOD: batch MAX_ITER: 10 N_TOPICS: 5 N_TOP_WORDS: 10 inputs: - dtm: + preprocessed_docs: config: - dtm_BUCKET_NAME: null - dtm_FILE_EXT: pkl - dtm_FILE_NAME: null - dtm_FILE_PATH: null - dtm_S3_ACCESS_KEY: null - dtm_S3_HOST: null - dtm_S3_PORT: null - dtm_S3_SECRET_KEY: null - description: Pkl file of your numpy representation of the document-term matrix - type: file - vocab: - config: - vocab_BUCKET_NAME: null - vocab_FILE_EXT: pkl - vocab_FILE_NAME: null - vocab_FILE_PATH: null - vocab_S3_ACCESS_KEY: null - vocab_S3_HOST: null - vocab_S3_PORT: null - vocab_S3_SECRET_KEY: null - description: Pkl file of a dictionary that maps all words to their index in the DTM - type: file + preprocessed_docs_DB_TABLE: null + preprocessed_docs_PG_HOST: null + preprocessed_docs_PG_PASS: null + preprocessed_docs_PG_PORT: null + preprocessed_docs_PG_USER: null + description: A database table, expected to have the doc_id, and tokens (list of strings) + type: pg_table outputs: doc_topic: config: diff --git a/dtm.pkl b/dtm.pkl deleted file mode 100644 index 0ea1d88..0000000 Binary files a/dtm.pkl and /dev/null differ diff --git a/main.py b/main.py index aae6284..5dba9f5 100644 --- a/main.py +++ b/main.py @@ -1,18 +1,18 @@ import logging -import pickle +import pandas as pd from scystream.sdk.core import entrypoint from scystream.sdk.env.settings import ( EnvSettings, InputSettings, OutputSettings, - FileSettings, PostgresSettings, ) -from scystream.sdk.file_handling.s3_manager import S3Operations from sqlalchemy import create_engine from algorithms.lda import LDAModeler +from algorithms.models import PreprocessedDocument +from algorithms.vectorizer import NLPVectorizer logging.basicConfig( level=logging.INFO, @@ -21,16 +21,8 @@ logger = logging.getLogger(__name__) -class DTMFileInput(FileSettings, InputSettings): - __identifier__ = "dtm" - - FILE_EXT: str = "pkl" - - -class VocabFileInput(FileSettings, InputSettings): - __identifier__ = "vocab" - - FILE_EXT: str = "pkl" +class PreprocessedDocuments(PostgresSettings, InputSettings): + __identifier__ = "preprocessed_docs" class DocTopicOutput(PostgresSettings, OutputSettings): @@ -47,50 +39,56 @@ class LDATopicModeling(EnvSettings): LEARNING_METHOD: str = "batch" N_TOP_WORDS: int = 10 - vocab: VocabFileInput - dtm: DTMFileInput + preprocessed_docs: PreprocessedDocuments doc_topic: DocTopicOutput topic_term: TopicTermsOutput -def write_df_to_postgres(df, settings: PostgresSettings): - logger.info(f"Writing DataFrame to DB table '{settings.DB_TABLE}'…") - - engine = create_engine( +def _make_engine(settings: PostgresSettings): + return create_engine( f"postgresql+psycopg2://{settings.PG_USER}:{settings.PG_PASS}" f"@{settings.PG_HOST}:{int(settings.PG_PORT)}/" ) + + +def write_df_to_postgres(df, settings: PostgresSettings): + logger.info(f"Writing DataFrame to DB table '{settings.DB_TABLE}'…") + engine = _make_engine(settings) df.to_sql(settings.DB_TABLE, engine, if_exists="replace", index=False) logger.info(f"Successfully wrote {len(df)} rows to '{settings.DB_TABLE}'.") +def read_table_from_postgres(settings: PostgresSettings) -> pd.DataFrame: + engine = _make_engine(settings) + query = f"SELECT * FROM {settings.DB_TABLE} ORDER BY doc_id;" + return pd.read_sql(query, engine) + + @entrypoint(LDATopicModeling) def lda_topic_modeling(settings): logger.info("Starting LDA topic modeling pipeline…") - logger.info("Downloading vocabulary file...") - S3Operations.download(settings.vocab, "vocab.pkl") - - logger.info("Loading vocab.pkl from disk...") - with open("vocab.pkl", "rb") as f: - vocab = pickle.load(f) + logger.info("Querying normalized docs from db...") + normalized_docs = read_table_from_postgres(settings.preprocessed_docs) - logger.info(f"Loaded vocab with {len(vocab)} terms.") - - logger.info("Downloading DTM file...") - S3Operations.download(settings.dtm, "dtm.pkl") - - logger.info("Loading dtm.pkl from disk...") - with open("dtm.pkl", "rb") as f: - dtm = pickle.load(f) + preprocessed_docs = [ + PreprocessedDocument( + doc_id=row["doc_id"], + tokens=row["tokens"] + ) + for _, row in normalized_docs.iterrows() + ] - logger.info(f"Loaded DTM with shape {dtm.shape}") + vectorizer = NLPVectorizer(preprocessed_docs) + vectorizer.analyze_frequencies() + vocab = vectorizer.build_vocabulary() + dtm = vectorizer.build_dtm() - # TODO: Check if dtm and vocab is of correct schema lda = LDAModeler( dtm=dtm, vocab=vocab, + doc_ids=vectorizer.doc_ids, n_topics=settings.N_TOPICS, max_iter=settings.MAX_ITER, learning_method=settings.LEARNING_METHOD, @@ -105,44 +103,3 @@ def lda_topic_modeling(settings): # TODO: Use Spark Integration here write_df_to_postgres(doc_topics, settings.doc_topic) write_df_to_postgres(topic_terms, settings.topic_term) - - -""" -if __name__ == "__main__": - test = LDATopicModeling( - vocab=VocabFileInput( - S3_HOST="http://localhost", - S3_PORT="9000", - S3_ACCESS_KEY="minioadmin", - S3_SECRET_KEY="minioadmin", - BUCKET_NAME="output-bucket", - FILE_PATH="output_file_path", - FILE_NAME="vocab_file_bib", - ), - dtm=DTMFileInput( - S3_HOST="http://localhost", - S3_PORT="9000", - S3_ACCESS_KEY="minioadmin", - S3_SECRET_KEY="minioadmin", - BUCKET_NAME="output-bucket", - FILE_PATH="output_file_path", - FILE_NAME="dtm_file_bib", - ), - doc_topic=DocTopicOutput( - PG_USER="postgres", - PG_PASS="postgres", - PG_HOST="localhost", - PG_PORT="5432", - DB_TABLE="doc_topic" - ), - topic_term=TopicTermsOutput( - PG_USER="postgres", - PG_PASS="postgres", - PG_HOST="localhost", - PG_PORT="5432", - DB_TABLE="topic_term" - ) - ) - - lda_topic_modeling(test) -""" diff --git a/test/files/dtm.pkl b/test/files/dtm.pkl deleted file mode 100644 index 0ea1d88..0000000 Binary files a/test/files/dtm.pkl and /dev/null differ diff --git a/test/files/norm_docs_dump.sql b/test/files/norm_docs_dump.sql new file mode 100644 index 0000000..384d9d1 --- /dev/null +++ b/test/files/norm_docs_dump.sql @@ -0,0 +1,34 @@ +-- +-- PostgreSQL database dump +-- + +-- Dumped from database version 13.23 (Debian 13.23-1.pgdg13+1) +-- Dumped by pg_dump version 13.23 (Debian 13.23-1.pgdg13+1) + +SET statement_timeout = 0; +SET lock_timeout = 0; +SET idle_in_transaction_session_timeout = 0; +SET client_encoding = 'UTF8'; +SET standard_conforming_strings = on; +SELECT pg_catalog.set_config('search_path', '', false); +SET check_function_bodies = false; +SET xmloption = content; +SET client_min_messages = warning; +SET row_security = off; + +-- +-- Data for Name: normalized_docs_bib; Type: TABLE DATA; Schema: public; Owner: postgres +-- + +CREATE TABLE public.norm_docs ( + doc_id TEXT PRIMARY KEY, + tokens TEXT[] +); + +INSERT INTO public.norm_docs (doc_id, tokens) VALUES ('WOS:001016714700004', '{modern,contemporari,transhuman,seen,recent,rise,academ,popular,relev,specif,naiv,metaphys,idea,immort,return,rise,"modern contemporari","contemporari transhuman","transhuman seen","seen recent","recent rise","rise academ","academ popular","popular relev","relev specif","specif naiv","naiv metaphys","metaphys idea","idea immort","immort return","return rise","modern contemporari transhuman","contemporari transhuman seen","transhuman seen recent","seen recent rise","recent rise academ","rise academ popular","academ popular relev","popular relev specif","relev specif naiv","specif naiv metaphys","naiv metaphys idea","metaphys idea immort","idea immort return","immort return rise",articl,refrain,ethic,polit,assess,transhuman,"articl refrain","refrain ethic","ethic polit","polit assess","assess transhuman","articl refrain ethic","refrain ethic polit","ethic polit assess","polit assess transhuman",critiqu,exact,metaphys,idealist,natur,transhuman,pursuit,digit,immort,idea,technolog,advanc,precis,artifici,gener,intellig,immort,virtual,self,possibl,"critiqu exact","exact metaphys","metaphys idealist","idealist natur","natur transhuman","transhuman pursuit","pursuit digit","digit immort","immort idea","idea technolog","technolog advanc","advanc precis","precis artifici","artifici gener","gener intellig","intellig immort","immort virtual","virtual self","self possibl","critiqu exact metaphys","exact metaphys idealist","metaphys idealist natur","idealist natur transhuman","natur transhuman pursuit","transhuman pursuit digit","pursuit digit immort","digit immort idea","immort idea technolog","idea technolog advanc","technolog advanc precis","advanc precis artifici","precis artifici gener","artifici gener intellig","gener intellig immort","intellig immort virtual","immort virtual self","virtual self possibl",articl,follow,form,immanuel,kant,paralog,critiqu,pure,reason,kant,concern,substanti,immort,natur,soul,experienti,imposs,"articl follow","follow form","form immanuel","immanuel kant","kant paralog","paralog critiqu","critiqu pure","pure reason","reason kant","kant concern","concern substanti","substanti immort","immort natur","natur soul","soul experienti","experienti imposs","articl follow form","follow form immanuel","form immanuel kant","immanuel kant paralog","kant paralog critiqu","paralog critiqu pure","critiqu pure reason","pure reason kant","reason kant concern","kant concern substanti","concern substanti immort","substanti immort natur","immort natur soul","natur soul experienti","soul experienti imposs",articl,offer,theoret,practic,paralog,fals,logic,infer,argu,transhumanist,claim,digit,immort,possibl,fundament,stem,incorrect,major,premis,"articl offer","offer theoret","theoret practic","practic paralog","paralog fals","fals logic","logic infer","infer argu","argu transhumanist","transhumanist claim","claim digit","digit immort","immort possibl","possibl fundament","fundament stem","stem incorrect","incorrect major","major premis","articl offer theoret","offer theoret practic","theoret practic paralog","practic paralog fals","paralog fals logic","fals logic infer","logic infer argu","infer argu transhumanist","argu transhumanist claim","transhumanist claim digit","claim digit immort","digit immort possibl","immort possibl fundament","possibl fundament stem","fundament stem incorrect","stem incorrect major","incorrect major premis",concern,substanti,natur,inform,inform,theoret,paralog,second,concern,infinit,transform,pure,plastic,inform,practic,paralog,"concern substanti","substanti natur","natur inform","inform inform","inform theoret","theoret paralog","paralog second","second concern","concern infinit","infinit transform","transform pure","pure plastic","plastic inform","inform practic","practic paralog","concern substanti natur","substanti natur inform","natur inform inform","inform inform theoret","inform theoret paralog","theoret paralog second","paralog second concern","second concern infinit","concern infinit transform","infinit transform pure","transform pure plastic","pure plastic inform","plastic inform practic","inform practic paralog"}'); +INSERT INTO public.norm_docs (doc_id, tokens) VALUES ('WOS:001322577100012', '{unit,nation,panel,digit,cooper,emphas,inclus,growth,digit,network,digit,public,good,util,multistakehold,system,approach,"unit nation","nation panel","panel digit","digit cooper","cooper emphas","emphas inclus","inclus growth","growth digit","digit network","network digit","digit public","public good","good util","util multistakehold","multistakehold system","system approach","unit nation panel","nation panel digit","panel digit cooper","digit cooper emphas","cooper emphas inclus","emphas inclus growth","inclus growth digit","growth digit network","digit network digit","network digit public","digit public good","public good util","good util multistakehold","util multistakehold system","multistakehold system approach",similarli,inform,commun,technolog,ICT,innov,intervent,program,govern,india,digit,north,east,vision,emphas,need,inclus,growth,ICT,northeast,region,"similarli inform","inform commun","commun technolog","technolog ICT","ICT innov","innov intervent","intervent program","program govern","govern india","india digit","digit north","north east","east vision","vision emphas","emphas need","need inclus","inclus growth","growth ICT","ICT northeast","northeast region","similarli inform commun","inform commun technolog","commun technolog ICT","technolog ICT innov","ICT innov intervent","innov intervent program","intervent program govern","program govern india","govern india digit","india digit north","digit north east","north east vision","east vision emphas","vision emphas need","emphas need inclus","need inclus growth","inclus growth ICT","growth ICT northeast","ICT northeast region",line,articl,present,insight,field,studi,conduct,rural,part,manipur,india,incident,found,applic,rural,part,develop,world,"line articl","articl present","present insight","insight field","field studi","studi conduct","conduct rural","rural part","part manipur","manipur india","india incident","incident found","found applic","applic rural","rural part","part develop","develop world","line articl present","articl present insight","present insight field","insight field studi","field studi conduct","studi conduct rural","conduct rural part","rural part manipur","part manipur india","manipur india incident","india incident found","incident found applic","found applic rural","applic rural part","rural part develop","part develop world",articl,envis,commun,driven,sociodigit,transform,northeast,region,india,"articl envis","envis commun","commun driven","driven sociodigit","sociodigit transform","transform northeast","northeast region","region india","articl envis commun","envis commun driven","commun driven sociodigit","driven sociodigit transform","sociodigit transform northeast","transform northeast region","northeast region india",quest,articl,highlight,sociopolit,challeng,digit,transform,provid,insight,inclus,ICT,region,infrastructur,util,citizen,smart,govern,servic,demand,digit,empower,citizen,social,welfar,capac,build,commun,engag,"quest articl","articl highlight","highlight sociopolit","sociopolit challeng","challeng digit","digit transform","transform provid","provid insight","insight inclus","inclus ICT","ICT region","region infrastructur","infrastructur util","util citizen","citizen smart","smart govern","govern servic","servic demand","demand digit","digit empower","empower citizen","citizen social","social welfar","welfar capac","capac build","build commun","commun engag","quest articl highlight","articl highlight sociopolit","highlight sociopolit challeng","sociopolit challeng digit","challeng digit transform","digit transform provid","transform provid insight","provid insight inclus","insight inclus ICT","inclus ICT region","ICT region infrastructur","region infrastructur util","infrastructur util citizen","util citizen smart","citizen smart govern","smart govern servic","govern servic demand","servic demand digit","demand digit empower","digit empower citizen","empower citizen social","citizen social welfar","social welfar capac","welfar capac build","capac build commun","build commun engag"}'); + + +-- +-- PostgreSQL database dump complete +-- diff --git a/test/files/vocab.pkl b/test/files/vocab.pkl deleted file mode 100644 index eeaed7f..0000000 Binary files a/test/files/vocab.pkl and /dev/null differ diff --git a/test/test_lda_entrypoint.py b/test/test_lda_entrypoint.py index c3d92ba..7087b70 100644 --- a/test/test_lda_entrypoint.py +++ b/test/test_lda_entrypoint.py @@ -1,12 +1,10 @@ import os -import boto3 import pytest import psycopg2 import time import pandas as pd from pathlib import Path -from botocore.exceptions import ClientError from main import lda_topic_modeling MINIO_USER = "minioadmin" @@ -19,35 +17,6 @@ N_TOPICS = 5 -def ensure_bucket(s3, bucket): - try: - s3.head_bucket(Bucket=bucket) - except ClientError as e: - error_code = e.response["Error"]["Code"] - if error_code in ("404", "NoSuchBucket"): - s3.create_bucket(Bucket=bucket) - else: - raise - - -def download_to_tmp(s3, bucket, key): - tmp_path = Path("/tmp") / key.replace("/", "_") - s3.download_file(bucket, key, str(tmp_path)) - return tmp_path - - -@pytest.fixture -def s3_minio(): - client = boto3.client( - "s3", - endpoint_url="http://localhost:9000", - aws_access_key_id=MINIO_USER, - aws_secret_access_key=MINIO_PWD - ) - ensure_bucket(client, BUCKET_NAME) - return client - - @pytest.fixture(scope="session") def postgres_conn(): """Wait until postgres is ready, then yield a live connection.""" @@ -69,51 +38,27 @@ def postgres_conn(): raise RuntimeError("Postgres did not start") -def test_lda_entrypoint(s3_minio, postgres_conn): - input_dtm_file_name = "dtm" - input_vocab_file_name = "vocab" - +def test_lda_entrypoint(postgres_conn): doc_topic_table_name = "doc_topic" topic_terms_table_name = "topic_terms" - dtm_path = Path(__file__).parent / "files" / f"{input_dtm_file_name}.pkl" - dtm_bytes = dtm_path.read_bytes() - - vocab_path = Path(__file__).parent / "files" / \ - f"{input_vocab_file_name}.pkl" - vocab_bytes = vocab_path.read_bytes() + sql_dump_path = Path(__file__).parent / "files" / "norm_docs_dump.sql" + with open(sql_dump_path, "r") as f: + sql = f.read() - s3_minio.put_object( - Bucket=BUCKET_NAME, - Key=f"{input_dtm_file_name}.pkl", - Body=dtm_bytes - ) - s3_minio.put_object( - Bucket=BUCKET_NAME, - Key=f"{input_vocab_file_name}.pkl", - Body=vocab_bytes - ) + cur = postgres_conn.cursor() + cur.execute("DROP TABLE IF EXISTS norm_docs;") + cur.execute(sql) env = { "N_TOPICS": "5", - "dtm_S3_HOST": "http://127.0.0.1", - "dtm_S3_PORT": "9000", - "dtm_S3_ACCESS_KEY": MINIO_USER, - "dtm_S3_SECRET_KEY": MINIO_PWD, - "dtm_BUCKET_NAME": BUCKET_NAME, - "dtm_FILE_PATH": "", - "dtm_FILE_NAME": input_dtm_file_name, - "dtm_FILE_EXT": "pkl", - - "vocab_S3_HOST": "http://127.0.0.1", - "vocab_S3_PORT": "9000", - "vocab_S3_ACCESS_KEY": MINIO_USER, - "vocab_S3_SECRET_KEY": MINIO_PWD, - "vocab_BUCKET_NAME": BUCKET_NAME, - "vocab_FILE_PATH": "", - "vocab_FILE_NAME": input_vocab_file_name, - "vocab_FILE_EXT": "pkl", + "preprocessed_docs_PG_HOST": "127.0.0.1", + "preprocessed_docs_PG_PORT": "5432", + "preprocessed_docs_PG_USER": POSTGRES_USER, + "preprocessed_docs_PG_PASS": POSTGRES_PWD, + "preprocessed_docs_DB_TABLE": "norm_docs", + "docs_to_topics_PG_HOST": "127.0.0.1", "docs_to_topics_PG_PORT": "5432", @@ -136,15 +81,16 @@ def test_lda_entrypoint(s3_minio, postgres_conn): cur = postgres_conn.cursor() # 1. doc-topic distribution - cur.execute(f"SELECT * FROM {doc_topic_table_name} ORDER BY 1;") + cur.execute(f"SELECT * FROM public.{doc_topic_table_name} ORDER BY 1;") doc_topics = pd.DataFrame(cur.fetchall(), columns=[ desc[0] for desc in cur.description]) - assert len(doc_topics) == 26 - assert doc_topics.shape[1] == N_TOPICS + assert len(doc_topics) == 2 + # expect N_TOPICS + 1 for doc_id + assert doc_topics.shape[1] == N_TOPICS + 1 # 2. topic-term listing cur.execute( - f"SELECT * FROM { + f"SELECT * FROM public.{ topic_terms_table_name} ORDER BY topic_id, weight DESC;") topic_terms = pd.DataFrame(cur.fetchall(), columns=[ desc[0] for desc in cur.description]) diff --git a/test/test_lda_modeler.py b/test/test_lda_modeler.py index e3cdae3..77b71f2 100644 --- a/test/test_lda_modeler.py +++ b/test/test_lda_modeler.py @@ -22,25 +22,40 @@ def small_dtm(): def test_lda_fit(small_dtm, small_vocab): - lda = LDAModeler(dtm=small_dtm, vocab=small_vocab, n_topics=2) + lda = LDAModeler( + dtm=small_dtm, + vocab=small_vocab, + doc_ids=["0", "1", "2"], + n_topics=2, + ) lda.fit() assert lda.lda.components_.shape == (2, 4) def test_extract_doc_topics(small_dtm, small_vocab): - lda = LDAModeler(dtm=small_dtm, vocab=small_vocab, n_topics=2) + lda = LDAModeler( + dtm=small_dtm, + vocab=small_vocab, + doc_ids=["0", "1", "2"], + n_topics=2, + ) lda.fit() df = lda.extract_doc_topics() assert isinstance(df, pd.DataFrame) - assert df.shape == (3, 2) - assert list(df.columns) == ["topic_0", "topic_1"] + assert df.shape == (3, 3) + assert list(df.columns) == ["doc_id", "topic_0", "topic_1"] def test_extract_topic_terms(small_dtm, small_vocab): - lda = LDAModeler(dtm=small_dtm, vocab=small_vocab, - n_topics=2, n_top_words=2) + lda = LDAModeler( + dtm=small_dtm, + vocab=small_vocab, + doc_ids=["1", "2", "3"], + n_topics=2, + n_top_words=2 + ) lda.fit() df = lda.extract_topic_terms() diff --git a/vocab.pkl b/vocab.pkl deleted file mode 100644 index eeaed7f..0000000 Binary files a/vocab.pkl and /dev/null differ