Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ jobs:
pip install -r requirements.txt

- name: Run Tests
env:
OLLAMA_API_KEY: ${{ secrets.OLLAMA_API_KEY }}
run: pytest -vv

build:
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.env

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
Expand Down
107 changes: 107 additions & 0 deletions algorithms/explanations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import logging
import pandas as pd
from ollama import Client

logger = logging.getLogger(__name__)


class TopicExplainer:
    """Generates short natural-language descriptions of topics via Ollama Cloud.

    Given a table of top terms per topic, builds one prompt per topic and asks
    the configured model for a 1-2 sentence description of what the topic
    represents.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "gpt-oss:120b",
        timeout: int = 60,
    ):
        """Create an authenticated client for the hosted Ollama API.

        Args:
            api_key: Bearer token for Ollama Cloud.
            model_name: Model used to generate the descriptions.
            timeout: Request timeout in seconds.
        """
        self.model_name = model_name
        self.timeout = timeout

        self.client = Client(
            host="https://ollama.com",
            headers={
                "Authorization": "Bearer " + api_key
            },
            timeout=self.timeout,
        )

    def explain_topics(
        self,
        topic_terms: pd.DataFrame,
        search_query: str,
        source: str,
        created_at: str,
    ) -> pd.DataFrame:
        """Generate one description per topic.

        Args:
            topic_terms: DataFrame with columns ``topic_id``, ``term`` and
                ``weight`` (one row per term).
            search_query: The search query the documents were retrieved with.
            source: Where the documents come from.
            created_at: Creation date of the query (free-form string).

        Returns:
            DataFrame with columns ``topic_id`` and ``description``.
        """
        logger.info("Generating topic explanations...")

        rows = []

        for topic_id, group in topic_terms.groupby("topic_id"):
            # Highest-weighted terms first so the prompt leads with the
            # most representative keywords.
            terms = (
                group.sort_values("weight", ascending=False)["term"]
                .astype(str)
                .tolist()
            )

            prompt = self._build_prompt(
                topic_id=topic_id,
                terms=terms,
                search_query=search_query,
                source=source,
                created_at=created_at,
            )

            description = self._call_ollama(prompt)

            rows.append(
                {
                    "topic_id": topic_id,
                    "description": description,
                }
            )

        # Explicit columns so an empty input still yields the expected schema
        # (pd.DataFrame([]) would otherwise have no columns at all).
        df = pd.DataFrame(rows, columns=["topic_id", "description"])

        logger.info("Generated explanations for %d topics", len(df))
        return df

    def _build_prompt(
        self,
        topic_id: int,
        terms: list[str],
        search_query: str,
        source: str,
        created_at: str,
    ) -> str:
        """Assemble the per-topic prompt sent to the model."""
        terms_str = ", ".join(terms)
        date_info = f"Creation date: {created_at}"

        return f"""
Context:
The documents come from {source}

Search Query:
"{search_query}"
{date_info}

Topic ID: {topic_id}

Top keywords for this topic:
{terms_str}

Task:
Describe in 1–2 concise sentences what this topic represents.
Focus on the research subfield or thematic area.
Do not list the keywords explicitly.
""".strip()

    def _call_ollama(self, prompt: str) -> str:
        """Send a single non-streaming chat request and return the reply text."""
        logger.debug("Calling Ollama Cloud...")

        response = self.client.chat(
            model=self.model_name,
            messages=[
                {"role": "user", "content": prompt}
            ],
            stream=False,
        )

        return response["message"]["content"].strip()
38 changes: 36 additions & 2 deletions cbc.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
author: Paul Kalhorn
author: Paul Kalhorn
description: Compute Block that offers topic modeling algorithms
docker_image: ghcr.io/rwth-time/topic-modeling/topic-modeling
entrypoints:
lda_topic_modeling:
description: Sklearn LDA Topic Modeling
description: Sklearn LDA Topic Modeling
envs:
LEARNING_METHOD: batch
MAX_ITER: 10
Expand Down Expand Up @@ -38,4 +38,38 @@ entrypoints:
top_terms_per_topic_PG_USER: null
description: A table that lists most likely terms for a topic
type: pg_table
topic_explanation:
description: Explains the topics using the query and top terms
envs:
MODEL_NAME: gpt-oss:120b
OLLAMA_API_KEY: ''
inputs:
query_information:
config:
query_information_input_DB_TABLE: null
query_information_input_PG_HOST: null
query_information_input_PG_PASS: null
query_information_input_PG_PORT: null
query_information_input_PG_USER: null
description: Information of the query used, must contain query, source, created_at
type: pg_table
topic_terms:
config:
topic_terms_input_DB_TABLE: null
topic_terms_input_PG_HOST: null
topic_terms_input_PG_PASS: null
topic_terms_input_PG_PORT: null
topic_terms_input_PG_USER: null
description: A table that lists most likely terms for a topic
type: pg_table
outputs:
explanations_output:
config:
explanations_output_DB_TABLE: null
explanations_output_PG_HOST: null
explanations_output_PG_PASS: null
explanations_output_PG_PORT: null
explanations_output_PG_USER: null
description: Output of the generated explanations, contains topic_id and description
type: pg_table
name: Topic-Modeling
61 changes: 60 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from algorithms.models import PreprocessedDocument
from algorithms.vectorizer import NLPVectorizer

from algorithms.explanations import TopicExplainer

logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
Expand Down Expand Up @@ -45,6 +47,34 @@ class LDATopicModeling(EnvSettings):
topic_term: TopicTermsOutput


class TopicTermsInput(PostgresSettings, InputSettings):
    """Postgres input table listing the top terms per topic
    (columns: topic_id, term, weight)."""
    __identifier__ = "topic_terms_input"


class QueryInformationInput(PostgresSettings, InputSettings):
    """
    Postgres input table with metadata about the executed search query.

    The topic-explanation entrypoint reads the first row of this table
    and expects the columns: query, source, created_at.
    """
    __identifier__ = "query_information_input"


class ExplanationsOutput(PostgresSettings, OutputSettings):
    """Postgres output table receiving the generated topic explanations
    (columns: topic_id, description)."""
    __identifier__ = "explanations_output"


class TopicExplanation(EnvSettings):
    """Environment settings for the topic-explanation entrypoint."""
    # Ollama model used to generate the descriptions.
    MODEL_NAME: str = "gpt-oss:120b"
    # Bearer token for Ollama Cloud; supplied via the environment
    # (empty default so the block can be configured without code changes).
    OLLAMA_API_KEY: str = ""

    topic_terms: TopicTermsInput
    query_information: QueryInformationInput

    explanations_output: ExplanationsOutput


def _make_engine(settings: PostgresSettings):
return create_engine(
f"postgresql+psycopg2://{settings.PG_USER}:{settings.PG_PASS}"
Expand All @@ -61,7 +91,7 @@ def write_df_to_postgres(df, settings: PostgresSettings):

def read_table_from_postgres(settings: PostgresSettings) -> pd.DataFrame:
    """Load the full contents of the configured table into a DataFrame."""
    sql = text(f'SELECT * FROM "{settings.DB_TABLE}";')
    return pd.read_sql(sql, _make_engine(settings))


Expand Down Expand Up @@ -109,3 +139,32 @@ def lda_topic_modeling(settings):
# TODO: Use Spark Integration here
write_df_to_postgres(doc_topics, settings.doc_topic)
write_df_to_postgres(topic_terms, settings.topic_term)


@entrypoint(TopicExplanation)
def topic_explanation(settings):
    """Entrypoint: generate an LLM description per topic and persist it.

    Reads the topic-terms table and the query-metadata table, asks the
    configured Ollama model for a short description of each topic, and
    writes the resulting (topic_id, description) rows to the output table.

    Raises:
        ValueError: If the query-information table contains no rows.
    """
    logger.info("Starting topic explanation...")

    # Use the module logger consistently (not the root logger) so records
    # carry this module's name.
    logger.info("Querying topic terms from db...")
    topic_terms = read_table_from_postgres(settings.topic_terms)

    logger.info("Querying query information from db...")
    query_information = read_table_from_postgres(settings.query_information)

    # Fail with a clear message instead of a bare IndexError from iloc[0].
    if query_information.empty:
        raise ValueError(
            "query_information table is empty; "
            "expected at least one row with query, source, created_at"
        )

    metadata = query_information.iloc[0]

    explainer = TopicExplainer(
        model_name=settings.MODEL_NAME,
        api_key=settings.OLLAMA_API_KEY,
    )

    explanations = explainer.explain_topics(
        topic_terms=topic_terms,
        search_query=metadata["query"],
        source=metadata["source"],
        created_at=metadata["created_at"],
    )

    write_df_to_postgres(explanations, settings.explanations_output)

    logger.info("Topic explanation block finished.")
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ pandas==2.3.2
SQLAlchemy==2.0.43
psycopg2-binary==2.9.10
pytest==9.0.1
ollama==0.6.1
175 changes: 175 additions & 0 deletions test/files/bike_norm.sql

Large diffs are not rendered by default.

Loading
Loading