Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion dockers/llm.rag.service/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import logging


# NOTE(review): calling basicConfig at import time in a shared module
# configures the root logger for every importer and pins it to DEBUG
# verbosity — consider leaving logging configuration to the application
# entry point instead.
logging.basicConfig(level=logging.DEBUG)


def format_context(results: List[Dict[str, Any]]) -> str:
"""Format search results into context for the LLM"""
context_parts = []
Expand Down Expand Up @@ -44,7 +47,22 @@ def trim_answer(generated_answer: str, label_separator: str) -> str:


def get_answer_with_settings(question, retriever, client, model_id, max_tokens, model_temperature, system_prompt):
docs = retriever.invoke(input=question)
search_params = {
"param": {
"metric_type": "L2",
"params": {"nprobe": 10},
},
"limit": 5,
"field_names": ["page_content", "metadata"],
"vector_field": ["dense", "sparse"],
"weights": [0.7, 0.2] # Weights for dense and sparse vectors
}

docs = retriever.get_relevant_documents(
query=question,
search_kwargs=search_params
)

num_of_docs = len(docs)
logging.info(f"Number of relevant documents retrieved and that will be used as context for query: {num_of_docs}")

Expand Down
139 changes: 139 additions & 0 deletions dockers/llm.rag.service/serveragllm_milvus_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "click",
# "faiss-cpu",
# "fastapi",
# "langchain",
# "langchain-community",
# "langchain-huggingface",
# "openai",
# "sentence-transformers",
# "uvicorn",
# "weaviate",
# "langchain_milvus",
# "langchain-openai",
# "pymilvus"
# ]
# ///

import os
import sys
import uvicorn

from functools import partial
from typing import Union

import click
from fastapi import FastAPI
from openai import OpenAI

from common import get_answer_with_settings


# System prompt sent with every chat completion: constrains the model to
# context-grounded, factual answers about support tickets.  Runtime string —
# any wording change alters model behavior.
SYSTEM_PROMPT="""You are a specialized support ticket assistant. Format your responses following these rules:
1. Answer the provided question only using the provided context.
2. Do not add the provided context to the generated answer.
3. Include relevant technical details when present or provide a summary of the comments in the ticket.
4. Include the submitter, assignee and collaborator for a ticket when this info is available.
5. If the question cannot be answered with the given context, please say so and do not attempt to provide an answer.
6. Do not create new questions related to the given question, instead answer only the provided question.
7. Provide a clear, direct and factual answer."""


def setup(
    relevant_docs: int,
    llm_server_url: str,
    model_id: str,
    max_tokens: int,
    model_temperature: float,
):
    """Build the FastAPI app serving RAG answers via a Milvus hybrid retriever.

    Args:
        relevant_docs: Currently unused here — the retriever's own limit
            controls how many documents come back.  TODO(review): wire this
            through or drop it from the signature.
        llm_server_url: Base URL of the OpenAI-compatible LLM server.
        model_id: Model identifier passed to the completion call.
        max_tokens: Completion token budget.
        model_temperature: Sampling temperature.

    Returns:
        The configured FastAPI application.
    """
    app = FastAPI()

    # TODO: move to imports
    from langchain_milvus.retrievers import MilvusCollectionHybridSearchRetriever
    from langchain_milvus.function import (
        BM25BuiltInFunction,
    )
    from langchain.embeddings import HuggingFaceEmbeddings
    from pymilvus import connections, Collection, WeightedRanker

    # TODO: pass through settings or params
    URI = "http://localhost:19530"
    collection_name = "test_milvus_collection"
    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)

    connections.connect(
        alias="default",
        uri=URI
    )

    # Connect to the existing collection and load it so searches can be served.
    collection = Collection(collection_name)
    collection.load()

    # Hybrid retriever querying both the dense (embedding) and sparse (BM25)
    # vector fields of the collection.
    retriever = MilvusCollectionHybridSearchRetriever(
        collection=collection,
        content_field="page_content",  # Field containing the document text
        anns_fields=["dense", "sparse"],  # Both vector fields
        metadata_fields=["metadata"],  # Include all metadata
        field_embeddings=[embeddings, BM25BuiltInFunction()],  # dense embedder + built-in BM25 for the sparse field
        # Equal-weight reranking of dense and sparse hits; also satisfies the
        # retriever's validation, which requires a reranker.
        rerank=WeightedRanker(0.5, 0.5),
    )

    print("Created Vector DB retriever successfully. \n")

    print("Creating an OpenAI client to the hosted model at URL: ", llm_server_url)
    try:
        client = OpenAI(base_url=llm_server_url, api_key="na")
    except Exception as e:
        # Boundary handler: without a client the service cannot run at all.
        print("Error creating client:", e)
        sys.exit(1)

    # Pre-bind all settings so the endpoint only needs the question.
    get_answer = partial(
        get_answer_with_settings,
        retriever=retriever,
        client=client,
        model_id=model_id,
        max_tokens=max_tokens,
        model_temperature=model_temperature,
        system_prompt=SYSTEM_PROMPT,
    )

    @app.get("/answer/{question}")
    def read_item(question: Union[str, None] = None):
        print(f"Received question: {question}")
        answer = get_answer(question)
        return {"question": question, "answer": answer}

    return app


MICROSOFT_MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
MOSAICML_MODEL_ID = "mosaicml/mpt-7b-chat"
RELEVANT_DOCS_DEFAULT = 2
MAX_TOKENS_DEFAULT = 64
MODEL_TEMPERATURE_DEFAULT = 0.01

# Service configuration from the environment, falling back to the defaults
# above.  All numeric values are coerced explicitly: os.getenv returns a str
# when the variable is set, so without int() relevant_docs would be a string
# in that case but an int when defaulted.
relevant_docs = int(os.getenv("RELEVANT_DOCS", RELEVANT_DOCS_DEFAULT))
llm_server_url = os.getenv("LLM_SERVER_URL", "http://localhost:11434/v1")
model_id = os.getenv("MODEL_ID", "llama2")
max_tokens = int(os.getenv("MAX_TOKENS", MAX_TOKENS_DEFAULT))
model_temperature = float(os.getenv("MODEL_TEMPERATURE", MODEL_TEMPERATURE_DEFAULT))

app = setup(relevant_docs, llm_server_url, model_id, max_tokens, model_temperature)


@click.command()
@click.option("--host", default="127.0.0.1", help="Host for the FastAPI server (default: 127.0.0.1)")
@click.option("--port", type=int, default=8000, help="Port for the FastAPI server (default: 8000)")
def run(host, port):
    """CLI entry point: serve the module-level `app` with Uvicorn.

    The app is passed by import string ("serveragllm_milvus_local:app") so
    that reload=True can re-import this module on code changes.
    """
    # Serve the app using Uvicorn
    uvicorn.run("serveragllm_milvus_local:app", host=host, port=port, reload=True)
6 changes: 5 additions & 1 deletion dockers/llm.vdb.service/.env_local_template
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ OUTPUT_FILENAME=/path/to/local/output_pickled.obj
# Vector DB Optional Settings
# EMBEDDING_CHUNK_SIZE=1000
# EMBEDDING_CHUNK_OVERLAP=100
# EMBEDDING_MODEL_NAME=sentence-transformers/all-MiniLM-L6-v2
# EMBEDDING_MODEL_NAME=sentence-transformers/all-MiniLM-L6-v2

# Milvus Vector DB Optional Settings
# MILVUS_URI="http://localhost:19530"
# MILVUS_COLLECTION_NAME="test_milvus_collection"
6 changes: 5 additions & 1 deletion dockers/llm.vdb.service/.env_s3_template
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ AWS_SECRET_ACCESS_KEY=my-secret-key
# Vector DB Optional Settings
# EMBEDDING_CHUNK_SIZE=1000
# EMBEDDING_CHUNK_OVERLAP=100
# EMBEDDING_MODEL_NAME=sentence-transformers/all-MiniLM-L6-v2
# EMBEDDING_MODEL_NAME=sentence-transformers/all-MiniLM-L6-v2

# Milvus Vector DB Optional Settings
# MILVUS_URI="http://localhost:19530"
# MILVUS_COLLECTION_NAME="test_milvus_collection"
38 changes: 38 additions & 0 deletions dockers/llm.vdb.service/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_milvus import BM25BuiltInFunction, Milvus


def load_jsonl_files_from_directory(directory):
Expand Down Expand Up @@ -77,3 +79,39 @@ def create_vectordb_from_data(
print("Convert to FAISS vectorstore")
vectorstore = FAISS.from_texts(texts, embeddings, metadatas=metadatas)
return vectorstore


def create_milvus_vectordb_from_data(
    data,
    embedding_model_name: str,
    milvus_uri: str,
    collection_name: str,
    chunk_size: int,
    chunk_overlap: int,
):
    """Chunk `data`, embed it, and store it in a Milvus hybrid collection.

    The collection gets a dense vector field (HuggingFace embeddings) and a
    sparse vector field (Milvus built-in BM25), enabling hybrid search.

    Args:
        data: Raw documents in the form accepted by
            chunk_documents_with_metadata.
        embedding_model_name: HuggingFace model for the dense embeddings.
        milvus_uri: Milvus connection URI.
        collection_name: Name of the target collection.
        chunk_size: Chunk size passed to the splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The populated Milvus vectorstore.
    """
    print("Start chunking documents")
    texts, metadatas = chunk_documents_with_metadata(data, chunk_size, chunk_overlap)

    # Pair each chunk with its metadata as a LangChain Document.
    docs = [
        Document(page_content=text, metadata=metadata)
        for text, metadata in zip(texts, metadatas)
    ]

    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    print("Convert to Milvus vectorstore")

    vectorstore = Milvus(
        embedding_function=embeddings,
        # "dense" holds the embedding vectors, "sparse" the BM25 output.
        vector_field=["dense", "sparse"],
        builtin_function=BM25BuiltInFunction(),
        collection_name=collection_name,
        connection_args={"uri": milvus_uri},
        auto_id=True
    )

    vectorstore.add_documents(documents=docs)
    return vectorstore
18 changes: 18 additions & 0 deletions dockers/llm.vdb.service/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ class S3Settings(BaseSettings):
description="Name of the embedding model to use"
)

milvus_uri: str = Field(
default="",
description="Milvus connection URI"
)
milvus_collection_name: str = Field(
default="",
description="Milvus collection name"
)

class Config:
env_file = ".env"

Expand All @@ -80,6 +89,15 @@ class LocalSettings(BaseSettings):
description="Name of the embedding model to use"
)

milvus_uri: str = Field(
default="",
description="Milvus connection URI"
)
milvus_collection_name: str = Field(
default="",
description="Milvus collection name"
)

class Config:
env_file = ".env"

Expand Down
16 changes: 13 additions & 3 deletions dockers/llm.vdb.service/createvectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import sys

from config import try_load_settings
from service import LocalDirDbCreationService, S3VectorDbCreationService
from service import (
LocalDirDbCreationService,
LocalDirMilvusDbCreationService,
S3VectorDbCreationService,
)


@click.command()
Expand All @@ -11,12 +15,18 @@ def run(env_file: str):
s3_settings, local_settings = try_load_settings(env_file)

if s3_settings:
if s3_settings.milvus_uri and s3_settings.milvus_collection_name:
raise "Missing config"
service = S3VectorDbCreationService(s3_settings)
service.create()

elif local_settings:
service = LocalDirDbCreationService(local_settings)
service.create()
if local_settings.milvus_uri and local_settings.milvus_collection_name:
service = LocalDirMilvusDbCreationService(local_settings)
service.create()
else:
service = LocalDirDbCreationService(local_settings)
service.create()

else:
# TODO: not really needed, error will be thrown earlier
Expand Down
21 changes: 21 additions & 0 deletions dockers/llm.vdb.service/createvectordb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import pytest
import s3fs
import subprocess

from botocore.session import Session
from moto.moto_server.threaded_moto_server import ThreadedMotoServer
Expand All @@ -23,6 +24,26 @@ def test_create_faiss_vector_db_using_local_files():
os.remove("test_data/output/output_pickled.obj")


@pytest.fixture(scope="module")
def standalone_environment():
    """Run tests against a standalone Milvus instance managed by standalone_embed.sh."""
    # Start the standalone environment before tests
    try:
        subprocess.run(["bash", "standalone_embed.sh", "start"], check=True)
        yield
    finally:
        # Stop the standalone environment after tests, even if tests fail
        # (and clean up after a partially-completed start, since the start
        # command is inside the try as well).
        subprocess.run(["bash", "standalone_embed.sh", "stop"], check=True)
        subprocess.run(["bash", "standalone_embed.sh", "delete"], check=True)


def test_create_milvus_vector_db_using_local_files(standalone_environment):
    """End-to-end: build a Milvus vector DB from local test files via the CLI.

    NOTE(review): if ctx.forward returns without raising SystemExit the test
    passes vacuously — confirm `run` always terminates via SystemExit, or add
    a fail-if-no-exit guard.
    """
    ctx = click.Context(run)
    try:
        ctx.forward(run, env_file="test_data/.env_local_milvus")
    except SystemExit as e:
        assert e.code == 0


@pytest.fixture(scope="module")
def s3_base():
# writable local S3 system
Expand Down
Loading