Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ jobs:

- name: Run flake8
uses: py-actions/flake8@v2
with:
args: --exclude=test

validate-compute-block:
name: Validate Compute Block Config
Expand Down Expand Up @@ -94,7 +96,7 @@ jobs:
env:
OLLAMA_API_KEY: ${{ secrets.OLLAMA_API_KEY }}
run: pytest -vv

build:
name: Build docker image
runs-on: ubuntu-latest
Expand All @@ -121,7 +123,7 @@ jobs:
tags: |
type=ref, event=pr
type=raw, value=latest, enable=${{ (github.ref == format('refs/heads/{0}', 'main')) }}

- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
Expand Down
52 changes: 17 additions & 35 deletions cbc.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
author: Paul Kalhorn
author: Paul Kalhorn
description: Compute Block that offers topic modeling algorithms
docker_image: ghcr.io/rwth-time/topic-modeling/topic-modeling
entrypoints:
lda_topic_modeling:
description: Sklearn LDA Topic Modeling
description: Sklearn LDA Topic Modeling
envs:
LEARNING_METHOD: batch
MAX_ITER: 10
Expand All @@ -12,64 +12,46 @@ entrypoints:
inputs:
preprocessed_docs:
config:
preprocessed_docs_DB_DSN: null
preprocessed_docs_DB_TABLE: null
preprocessed_docs_PG_HOST: null
preprocessed_docs_PG_PASS: null
preprocessed_docs_PG_PORT: null
preprocessed_docs_PG_USER: null
description: A database table, expected to have the doc_id, and tokens (list of strings)
type: pg_table
description: A database table, expected to have the doc_id, and tokens (list of strings)
type: database_table
outputs:
doc_topic:
config:
docs_to_topics_DB_DSN: null
docs_to_topics_DB_TABLE: null
docs_to_topics_PG_HOST: null
docs_to_topics_PG_PASS: null
docs_to_topics_PG_PORT: null
docs_to_topics_PG_USER: null
description: A table that maps documents to their topic-likelihoods
type: pg_table
type: database_table
topic_term:
config:
top_terms_per_topic_DB_DSN: null
top_terms_per_topic_DB_TABLE: null
top_terms_per_topic_PG_HOST: null
top_terms_per_topic_PG_PASS: null
top_terms_per_topic_PG_PORT: null
top_terms_per_topic_PG_USER: null
description: A table that lists most likely terms for a topic
type: pg_table
type: database_table
topic_explanation:
description: Explains the topics using the query and top terms
description: Explains the topics using the query and top terms
envs:
MODEL_NAME: gpt-oss:120b
OLLAMA_API_KEY: ''
inputs:
query_information:
config:
query_information_input_DB_DSN: null
query_information_input_DB_TABLE: null
query_information_input_PG_HOST: null
query_information_input_PG_PASS: null
query_information_input_PG_PORT: null
query_information_input_PG_USER: null
description: Information of the query used, must contain query, source, created_at
type: pg_table
description: Information of the query used, must contain query, source, created_at
type: database_table
topic_terms:
config:
topic_terms_input_DB_DSN: null
topic_terms_input_DB_TABLE: null
topic_terms_input_PG_HOST: null
topic_terms_input_PG_PASS: null
topic_terms_input_PG_PORT: null
topic_terms_input_PG_USER: null
description: A table that lists most likely terms for a topic
type: pg_table
type: database_table
outputs:
explanations_output:
config:
explanations_output_DB_DSN: null
explanations_output_DB_TABLE: null
explanations_output_PG_HOST: null
explanations_output_PG_PASS: null
explanations_output_PG_PORT: null
explanations_output_PG_USER: null
description: Output of the generated explanations, contains topic_id and description
type: pg_table
type: database_table
name: Topic-Modeling
115 changes: 48 additions & 67 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import logging
import hashlib
import pandas as pd

from scystream.sdk.core import entrypoint
from scystream.sdk.database_handling.database_manager import (
PandasDatabaseOperations,
)
from scystream.sdk.env.settings import (
EnvSettings,
InputSettings,
OutputSettings,
PostgresSettings,
DatabaseSettings,
)
from sqlalchemy import create_engine, text
from sqlalchemy.sql import quoted_name

from algorithms.lda import LDAModeler
from algorithms.models import PreprocessedDocument
Expand All @@ -20,35 +19,20 @@

logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


def _normalize_table_name(table_name: str) -> str:
    """Return ``table_name`` shortened to fit PostgreSQL's identifier limit.

    PostgreSQL limits identifiers to 63 *bytes*, not characters, so the
    length is measured on the UTF-8 encoding of the name. Names that fit are
    returned unchanged. Longer names are cut down and suffixed with ``_`` and
    the first 10 hex digits of the SHA-1 of the full original name, so that
    two long names sharing a prefix still map to different results.

    Args:
        table_name: The requested table name.

    Returns:
        A name whose UTF-8 encoding is at most 63 bytes long.
    """
    max_length = 63
    encoded = table_name.encode("utf-8")
    if len(encoded) <= max_length:
        return table_name
    digest = hashlib.sha1(encoded).hexdigest()[:10]
    # Reserve room for the "_" separator and the digest suffix.
    prefix_length = max_length - len(digest) - 1
    # Cut on the byte level; errors="ignore" drops a multi-byte character
    # that the cut would otherwise split in half.
    prefix = encoded[:prefix_length].decode("utf-8", errors="ignore")
    return f"{prefix}_{digest}"


def _resolve_db_table(settings: PostgresSettings) -> str:
    """Normalize ``settings.DB_TABLE`` in place and return the new name.

    The configured table name is passed through ``_normalize_table_name``
    and the result is written back to ``settings.DB_TABLE``, so later reads
    of the settings object see the same name that is returned here.
    """
    safe_name = _normalize_table_name(settings.DB_TABLE)
    settings.DB_TABLE = safe_name
    return safe_name


class PreprocessedDocuments(PostgresSettings, InputSettings):
class PreprocessedDocuments(DatabaseSettings, InputSettings):
__identifier__ = "preprocessed_docs"


class DocTopicOutput(PostgresSettings, OutputSettings):
class DocTopicOutput(DatabaseSettings, OutputSettings):
__identifier__ = "docs_to_topics"


class TopicTermsOutput(PostgresSettings, OutputSettings):
class TopicTermsOutput(DatabaseSettings, OutputSettings):
__identifier__ = "top_terms_per_topic"


Expand All @@ -64,21 +48,22 @@ class LDATopicModeling(EnvSettings):
topic_term: TopicTermsOutput


class TopicTermsInput(PostgresSettings, InputSettings):
class TopicTermsInput(DatabaseSettings, InputSettings):
__identifier__ = "topic_terms_input"


class QueryInformationInput(PostgresSettings, InputSettings):
class QueryInformationInput(DatabaseSettings, InputSettings):
"""
The TopicExplanation entrypoint needs information about the query that was
actually executed; this query information includes the query and its source.

Expected columns: query, source, created_at
"""

__identifier__ = "query_information_input"


class ExplanationsOutput(PostgresSettings, OutputSettings):
class ExplanationsOutput(DatabaseSettings, OutputSettings):
__identifier__ = "explanations_output"


Expand All @@ -92,33 +77,6 @@ class TopicExplanation(EnvSettings):
explanations_output: ExplanationsOutput


def _make_engine(settings: PostgresSettings):
    """Create a SQLAlchemy engine for the PostgreSQL server in ``settings``.

    The connection URL is assembled with ``URL.create`` instead of string
    formatting so that user names or passwords containing URL-reserved
    characters (``@``, ``:``, ``/`` ...) are escaped correctly.

    Args:
        settings: Settings object providing ``PG_USER``, ``PG_PASS``,
            ``PG_HOST`` and ``PG_PORT``.

    Returns:
        The engine returned by ``sqlalchemy.create_engine``. No database
        name is set, matching the previous URL that ended in ``/``.
    """
    from sqlalchemy.engine import URL

    url = URL.create(
        drivername="postgresql+psycopg2",
        username=settings.PG_USER,
        password=settings.PG_PASS,
        host=settings.PG_HOST,
        port=int(settings.PG_PORT),
    )
    return create_engine(url)


def write_df_to_postgres(df, settings: PostgresSettings):
    """Write ``df`` to the table named by ``settings``, replacing it.

    The table name is first normalized via ``_resolve_db_table`` (which also
    updates ``settings.DB_TABLE``). An existing table of that name is
    replaced and the DataFrame index is not written.

    Args:
        df: The pandas DataFrame to persist.
        settings: Connection and table settings for the target database.
    """
    resolved_table_name = _resolve_db_table(settings)
    logger.info("Writing DataFrame to DB table '%s'…", resolved_table_name)
    engine = _make_engine(settings)
    try:
        # quote=True forces the identifier to be quoted in the emitted SQL.
        table_name = quoted_name(resolved_table_name, quote=True)
        df.to_sql(table_name, engine, if_exists="replace", index=False)
    finally:
        # The engine is created per call; release its connection pool so
        # repeated calls do not accumulate open connections.
        engine.dispose()
    logger.info(
        "Successfully wrote %s rows to '%s'.",
        len(df),
        resolved_table_name,
    )


def read_table_from_postgres(settings: PostgresSettings) -> pd.DataFrame:
    """Read the whole table named by ``settings`` into a DataFrame.

    The table name is normalized via ``_resolve_db_table`` (which also
    updates ``settings.DB_TABLE``) and then selected with ``SELECT *``.

    Args:
        settings: Connection and table settings for the source database.

    Returns:
        A DataFrame holding every row and column of the table.
    """
    resolved_table_name = _resolve_db_table(settings)
    # Inside a double-quoted SQL identifier a literal quote is written as
    # two quotes; escape it so the name cannot terminate the identifier.
    escaped_name = resolved_table_name.replace('"', '""')
    engine = _make_engine(settings)
    try:
        query = text(f'SELECT * FROM "{escaped_name}";')
        return pd.read_sql(query, engine)
    finally:
        # The engine is created per call; release its connection pool.
        engine.dispose()


def parse_pg_array(val):
if isinstance(val, str):
return val.strip("{}").split(",")
Expand All @@ -130,12 +88,16 @@ def lda_topic_modeling(settings):
logger.info("Starting LDA topic modeling pipeline…")

logger.info("Querying normalized docs from db...")
normalized_docs = read_table_from_postgres(settings.preprocessed_docs)
preprocessed_docs_db = PandasDatabaseOperations(
settings.preprocessed_docs.DB_DSN
)
normalized_docs = preprocessed_docs_db.read(
table=settings.preprocessed_docs.DB_TABLE
)

preprocessed_docs = [
PreprocessedDocument(
doc_id=row["doc_id"],
tokens=parse_pg_array(row["tokens"])
doc_id=row["doc_id"], tokens=parse_pg_array(row["tokens"])
)
for _, row in normalized_docs.iterrows()
]
Expand All @@ -153,42 +115,61 @@ def lda_topic_modeling(settings):
max_iter=settings.MAX_ITER,
learning_method=settings.LEARNING_METHOD,
random_state=42,
n_top_words=settings.N_TOP_WORDS
n_top_words=settings.N_TOP_WORDS,
)
lda.fit()

doc_topics = lda.extract_doc_topics()
topic_terms = lda.extract_topic_terms()

# TODO: Use Spark Integration here
write_df_to_postgres(doc_topics, settings.doc_topic)
write_df_to_postgres(topic_terms, settings.topic_term)
logging.info("Writing dataframes to db...")
doc_topic_db = PandasDatabaseOperations(settings.doc_topic.DB_DSN)
topic_terms_db = PandasDatabaseOperations(settings.topic_term.DB_DSN)

doc_topic_db.write(
table=settings.doc_topic.DB_TABLE, data=doc_topics, mode="overwrite"
)
topic_terms_db.write(
table=settings.topic_term.DB_TABLE, data=topic_terms, mode="overwrite"
)


@entrypoint(TopicExplanation)
def topic_explanation(settings):
logger.info("Starting topic explaination...")

logging.info("Querying topic terms from db...")
topic_terms = read_table_from_postgres(settings.topic_terms)
topic_terms_db = PandasDatabaseOperations(settings.topic_terms.DB_DSN)
topic_terms = topic_terms_db.read(table=settings.topic_terms.DB_TABLE)

logging.info("Querying query information from db...")
query_information = read_table_from_postgres(settings.query_information)
query_info_db = PandasDatabaseOperations(settings.query_information.DB_DSN)
query_information = query_info_db.read(
table=settings.query_information.DB_TABLE
)

metadata = query_information.iloc[0]
print(metadata)

explainer = TopicExplainer(
model_name=settings.MODEL_NAME,
api_key=settings.OLLAMA_API_KEY
model_name=settings.MODEL_NAME, api_key=settings.OLLAMA_API_KEY
)

explainations = explainer.explain_topics(
explanations = explainer.explain_topics(
topic_terms=topic_terms,
search_query=metadata["query"],
source=metadata["source"],
created_at=metadata["created_at"]
created_at=metadata["created_at"],
)

write_df_to_postgres(explainations, settings.explanations_output)
explainations_output_db = PandasDatabaseOperations(
settings.explanations_output.DB_DSN
)
explainations_output_db.write(
table=settings.explanations_output.DB_TABLE,
data=explanations,
mode="overwrite",
)

logging.info("Topic explanation block finished.")
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
scystream-sdk==1.2.2
scystream-sdk[database,postgres]==1.4.0
scikit-learn==1.7.2
pandas==2.3.2
numpy==2.3.3
Expand Down
28 changes: 8 additions & 20 deletions test/test_explanation_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def postgres_conn():
port=5432,
user="postgres",
password="postgres",
database="postgres"
database="postgres",
)
conn.autocommit = True
yield conn
Expand Down Expand Up @@ -84,26 +84,14 @@ def test_topic_explanation_with_real_ollama(postgres_conn):
# ------------------------------------------------------------------

env = {
"topic_terms_input_PG_HOST": "127.0.0.1",
"topic_terms_input_PG_PORT": "5432",
"topic_terms_input_PG_USER": "postgres",
"topic_terms_input_PG_PASS": "postgres",
"topic_terms_input_DB_DSN": "postgresql://postgres:postgres@127.0.0.1:5432/postgres",
"topic_terms_input_DB_TABLE": topic_terms_table,

"query_information_input_PG_HOST": "127.0.0.1",
"query_information_input_PG_PORT": "5432",
"query_information_input_PG_USER": "postgres",
"query_information_input_PG_PASS": "postgres",
"query_information_input_DB_DSN": "postgresql://postgres:postgres@127.0.0.1:5432/postgres",
"query_information_input_DB_TABLE": query_information_table,

"explanations_output_PG_HOST": "127.0.0.1",
"explanations_output_PG_PORT": "5432",
"explanations_output_PG_USER": "postgres",
"explanations_output_PG_PASS": "postgres",
"explanations_output_DB_DSN": "postgresql://postgres:postgres@127.0.0.1:5432/postgres",
"explanations_output_DB_TABLE": explanations_output_table,

"MODEL_NAME": "gpt-oss:120b",
"API_KEY": os.environ.get("OLLAMA_API_KEY")
"API_KEY": os.environ.get("OLLAMA_API_KEY"),
}

for k, v in env.items():
Expand All @@ -120,10 +108,10 @@ def test_topic_explanation_with_real_ollama(postgres_conn):
# ------------------------------------------------------------------

cur.execute(
f"SELECT * FROM public.{explanations_output_table} ORDER BY topic_id;")
f"SELECT * FROM public.{explanations_output_table} ORDER BY topic_id;"
)
results = pd.DataFrame(
cur.fetchall(),
columns=[desc[0] for desc in cur.description]
cur.fetchall(), columns=[desc[0] for desc in cur.description]
)

assert len(results) == 5
Expand Down
Loading
Loading