Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/workspaces/cookiefactoryv3/assistant/.gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
.langchain.db

# Dockerfile
.env
.python-version

test_run.sh
*.swp
package-lock.json
__pycache__
Expand Down
31 changes: 31 additions & 0 deletions src/workspaces/cookiefactoryv3/assistant/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
FROM --platform=linux/amd64 node:20-slim
RUN npm install -g pnpm

# Install Python and other dependencies
RUN apt-get update && apt-get install -y \
python3 \
python3-pip \
git \
curl \
unzip \
jq

# Install AWS CLI
RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" && \
unzip awscliv2.zip && \
./aws/install && \
rm -rf awscliv2.zip aws


# Copy your application files
WORKDIR /app
COPY . .

# Install dependencies using your install.sh script
RUN sh install.sh

# Expose the port Chainlit will listen on
EXPOSE 8000

# Start the Chainlit app using your run.sh script
CMD ["sh", "run.sh"]
33 changes: 21 additions & 12 deletions src/workspaces/cookiefactoryv3/assistant/app/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,12 @@
from chainlit.context import context

from lib.router import LLMRouterChain, MultiRouteChain, create_routes
from lib.llm import get_bedrock_text
from lib.llm import get_bedrock_text_v3_sonnet
from lib.context_memory import EntityContextMemory
from lib.tools.qa import init_db
from lib.initial_diagnosis import InitialDiagnosisChain

print('initializing vector db')

init_db(os.path.join(os.path.dirname(__file__), '..', 'public', './freezer_tunnel_manual.pdf'))

print('done initializing vector db')

## To be implemented with update to chainlit
#from chainlit.oauth_providers import AWSCognitoOAuthProvider

welcome_message="""
Hi, I'm the AI assistant of the Cookie Factory. I'm here to help you diagnose and resolve issues \
Expand All @@ -41,15 +36,28 @@
There is an ongoing event [#{event_id}](https://example.com/issue/{event_id}). Do you want to run an initial diagnosis of the issue?
"""

## To be implemented with update to chainlit
#cl.oauth_providers = [
# AWSCognitoOAuthProvider(
# client_id=os.environ["OAUTH_COGNITO_CLIENT_ID"],
# client_secret=os.environ["OAUTH_COGNITO_CLIENT_SECRET"],
# domain=os.environ["OAUTH_COGNITO_DOMAIN"],
# )
#]

@cl.on_chat_start
async def start():

## To be implemented with udpdate to chainlit
#if not cl.user_session.get("is_authenticated"):
# await cl.oauth_providers[0].authorize()

LLMRouterChain.update_forward_refs()
MultiRouteChain.update_forward_refs()

memory = EntityContextMemory()
routes = create_routes(memory)
llm_chain = MultiRouteChain.from_prompts(llm=get_bedrock_text(), prompt_infos=routes)
llm_chain = MultiRouteChain.from_prompts(llm=get_bedrock_text_v3_sonnet(), prompt_infos=routes)

cl.user_session.set("chain", llm_chain)

Expand All @@ -71,7 +79,7 @@ async def start():
async def main(message, context):
llm_chain = cl.user_session.get("chain")

res = await llm_chain.acall(
res = await llm_chain.ainvoke(
message,
callbacks=[cl.AsyncLangchainCallbackHandler()])

Expand All @@ -88,9 +96,9 @@ async def on_action(action):
await action.remove()

cb = cl.AsyncLangchainCallbackHandler()
chain = InitialDiagnosisChain.from_llm(llm=get_bedrock_text())
chain = InitialDiagnosisChain.from_llm(llm=get_bedrock_text_v3_sonnet())

res = await chain.acall({
res = await chain.ainvoke({
'event_title': event_title,
'event_description': event_description,
'event_timestamp': event_timestamp
Expand All @@ -105,6 +113,7 @@ async def on_action(action):
@cl.action_callback("agent_actions")
async def on_action(action):
event_entity_id = context.session.user_data.get('event_entity_id')

if action.value == "3d":
await cl.Message(content=f"Navigating to the issue site.").send()
await action.remove()
Expand Down
3 changes: 3 additions & 0 deletions src/workspaces/cookiefactoryv3/assistant/app/lib/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ def get_bedrock_region():

def get_workspace_id():
return os.getenv("WORKSPACE_ID")

def get_knowledge_base_id():
return os.getenv("KNOWLEDGE_BASE_ID")
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing import Any, Dict, List, Optional

from langchain import PromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain.chains import LLMChain

from langchain.callbacks.manager import (
Expand All @@ -15,15 +15,15 @@
from langchain.chains.base import Chain
from langchain.schema.language_model import BaseLanguageModel

from .llm import get_bedrock_text, get_processed_prompt_template
from .llm import get_bedrock_text_v3_sonnet, get_processed_prompt_template_sonnet

default_llm = get_bedrock_text()
default_llm = get_bedrock_text_v3_sonnet()

question_classifier_prompt = """
You are a technical assistant to help the cooke line operators to investigate product quality issues. \
Your task is take the "Collected Information" from alarm systems, summarize the issue and provide prescriptive suggestions as "Initial Diagnosis" based on \
your knowledge about cookie production to provide initial suggestions for line operators to investigate the issue. Be concise and professional in the response. \
Translate technical terms to business terms so it's easy for line operators to read and understand, for example, timestamps should be converted to local user friendly format.
Translate technical terms to business terms so it's easy for line operators to read and understand, for example, timestamps should be converted to local user friendly format and use today's date instead of the one provided for October 23rd 2023.

<example>
Collected information
Expand Down Expand Up @@ -95,7 +95,7 @@ def from_llm(
) -> InitialDiagnosisChain:
router_template = question_classifier_prompt
router_prompt = PromptTemplate(
template=get_processed_prompt_template(router_template),
template=get_processed_prompt_template_sonnet(router_template),
input_variables=["event_title", "event_description", "event_timestamp"],
)
llm_chain = LLMChain(llm=llm, prompt=router_prompt)
Expand Down
76 changes: 64 additions & 12 deletions src/workspaces/cookiefactoryv3/assistant/app/lib/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,19 @@

import boto3

from langchain.llms.bedrock import Bedrock
#from langchain.llms.bedrock import Bedrock
from langchain_community.chat_models import BedrockChat
from langchain.embeddings.bedrock import BedrockEmbeddings

from botocore.config import Config

from .env import get_bedrock_region

available_models = [
"amazon.titan-tg1-large",
"anthropic.claude-v2",
"ai21.j2-ultra",
"ai21.j2-mid",
"anthropic.claude-instant-v1",
"anthropic.claude-v1"
]

# the current model used for text generation
text_model_id = "anthropic.claude-instant-v1"
text_v3_haiku_model_id = "anthropic.claude-3-haiku-20240307-v1:0"
text_v2_model_id = "anthropic.claude-v2"
text_v3_sonnet_model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
embedding_model_id = "amazon.titan-embed-text-v1"

model_kwargs = {
Expand All @@ -34,6 +28,22 @@
"temperature": 0.1,
"top_p": 0.9,
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 2048,
"temperature": 0.0,
"top_k": 250,
"top_p": 1,
"stop_sequences": ["\n\nHuman"]
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 2048,
"temperature": 0.0,
"top_k": 250,
"top_p": 1,
"stop_sequences": ["\n\nHuman"]
},
"anthropic.claude-instant-v1": {
"max_tokens_to_sample": 2048,
"temperature": 0.1,
Expand All @@ -42,11 +52,15 @@
}

prompt_template_prefix = {
"anthropic.claude-3-haiku-20240307-v1:0": "\n\nHuman: ",
"anthropic.claude-3-sonnet-20240229-v1:0": "\n\nHuman: ",
"anthropic.claude-v2": "\n\nHuman: ",
"anthropic.claude-instant-v1": "\n\nHuman: "
}

prompt_template_postfix = {
"anthropic.claude-3-haiku-20240307-v1:0": "\n\nAssistant: ",
"anthropic.claude-3-sonnet-20240229-v1:0": "\n\nAssistant:",
"anthropic.claude-v2": "\n\nAssistant:",
"anthropic.claude-instant-v1": "\n\nAssistant:"
}
Expand All @@ -57,6 +71,8 @@ def template_proc(template):
return template_proc

prompt_template_procs = {
"anthropic.claude-3-haiku-20240307-v1:0": get_template_proc("anthropic.claude-3-haiku-20240307-v1:0"),
"anthropic.claude-3-sonnet-20240229-v1:0": get_template_proc("anthropic.claude-3-sonnet-20240229-v1:0"),
"anthropic.claude-v2": get_template_proc("anthropic.claude-v2"),
"anthropic.claude-instant-v1": get_template_proc("anthropic.claude-instant-v1")
}
Expand All @@ -73,16 +89,34 @@ def template_proc(template):
'mode': 'standard'
}
))

bedrock_agents = boto3.client('bedrock-agent-runtime', get_bedrock_region(), config=Config(
retries = {
'max_attempts': 10,
'mode': 'standard'
}
))

response = bedrock.list_foundation_models()
print(response.get('modelSummaries'))

def get_bedrock_text():
llm = Bedrock(model_id=text_model_id, client=bedrock_runtime)
llm = BedrockChat(model_id=text_model_id, streaming=True, client=bedrock_runtime)
llm.model_kwargs = model_kwargs.get(text_model_id, {})
return llm

def get_bedrock_text_v3_haiku():
llm = BedrockChat(model_id=text_v3_haiku_model_id, streaming=True, client=bedrock_runtime)
llm.model_kwargs = model_kwargs.get(text_v3_haiku_model_id, {})
return llm

def get_bedrock_text_v3_sonnet():
llm = BedrockChat(model_id=text_v3_sonnet_model_id, streaming=True, client=bedrock_runtime)
llm.model_kwargs = model_kwargs.get(text_v3_sonnet_model_id, {})
return llm

def get_bedrock_text_v2():
llm = Bedrock(model_id=text_v2_model_id, client=bedrock_runtime)
llm = BedrockChat(model_id=text_v2_model_id, streaming=True, client=bedrock_runtime)
llm.model_kwargs = model_kwargs.get(text_v2_model_id, {})
return llm

Expand Down Expand Up @@ -110,3 +144,21 @@ def get_postfix_prompt_template(template):
return template + prompt_template_postfix[text_model_id]
else:
return template

def get_processed_prompt_template_sonnet(template):
if text_v3_sonnet_model_id in prompt_template_procs:
return prompt_template_procs[text_v3_sonnet_model_id](template)
else:
return template

def get_prefix_prompt_template_sonnet(template):
if text_v3_sonnet_model_id in prompt_template_prefix:
return prompt_template_prefix[text_v3_sonnet_model_id] + template
else:
return template

def get_postfix_prompt_template_sonnet(template):
if text_v3_sonnet_model_id in prompt_template_postfix:
return template + prompt_template_postfix[text_v3_sonnet_model_id]
else:
return template
17 changes: 13 additions & 4 deletions src/workspaces/cookiefactoryv3/assistant/app/lib/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing import Any, Dict, List, Mapping, NamedTuple, Optional

from langchain import PromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain.chains import LLMChain

from langchain.callbacks.manager import (
Expand All @@ -21,10 +21,11 @@
from .tools.qa import QAChain
from .tools.graph import GraphChain
from .tools.general import GeneralChain
from .tools.inspect_sensor_data import InspectChain

from .llm import get_bedrock_text, get_processed_prompt_template
from .llm import get_bedrock_text_v3_sonnet, get_processed_prompt_template_sonnet

default_llm = get_bedrock_text()
default_llm = get_bedrock_text_v3_sonnet()

question_classifier_prompt = """
You are given an instruction. The instruction is either a command or a question. You need to decide the type of the instruction provided by user.
Expand All @@ -34,6 +35,7 @@
- 3dview: if the instruction is a command about manipulating the 3D viewer
- doc: if the instruction is a question about standard procedures in the knowledge base
- graph: if the instruction is a question about finding information of the entities in the factory
- inspect: if the instruction is a question about analyzing historical data for a freezer
- general: if none of the above applies

You must give an answer using one of the valid options, and you should write out the answer without further explanation.
Expand All @@ -46,6 +48,9 @@
Instruction: how to operate the cookie line?
Answer: doc

Instruction: Can you analyze the last 15 minutes of data for the freezer tunnel for any potential issues?
Answer: inspect

Instruction: what are the potential causes of the inconsistent cookie shape?
Answer: general

Expand Down Expand Up @@ -134,7 +139,7 @@ def from_llm(
) -> LLMRouterChain:
router_template = question_classifier_prompt
router_prompt = PromptTemplate(
template=get_processed_prompt_template(router_template),
template=get_processed_prompt_template_sonnet(router_template),
input_variables=["question"],
)
llm_chain = LLMChain(llm=llm, prompt=router_prompt)
Expand Down Expand Up @@ -224,6 +229,10 @@ def create_routes(memory):
"name": "general",
"chain": GeneralChain()
},
{
"name": "inspect",
"chain": InspectChain()
},
{
"name": "3dview",
"chain": ViewChain.create(memory=memory)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@

from typing import Any, Dict, List, Optional

from langchain import LLMChain, PromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain import LLMChain
from langchain.agents import tool
from langchain.chains.base import Chain
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)

from ..llm import get_bedrock_text, get_processed_prompt_template
from ..llm import get_bedrock_text_v3_sonnet, get_processed_prompt_template_sonnet

llm = get_bedrock_text()
prompt_template = get_processed_prompt_template("{question}")
llm = get_bedrock_text_v3_sonnet()
prompt_template = get_processed_prompt_template_sonnet("{question}")

def get_tool_metadata():
return {
Expand Down
Loading