|
| 1 | +from typing import Any, Dict, List |
| 2 | + |
| 3 | +from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema |
| 4 | +from langchain.chat_models import ChatOpenAI |
| 5 | +from langchain.graphs import Neo4jGraph |
| 6 | +from langchain.memory import ChatMessageHistory |
| 7 | +from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder |
| 8 | +from langchain.pydantic_v1 import BaseModel |
| 9 | +from langchain.schema.output_parser import StrOutputParser |
| 10 | +from langchain.schema.runnable import RunnablePassthrough |
| 11 | + |
# Connection to Neo4j (connection details presumably come from the standard
# NEO4J_* environment variables — confirm with deployment config)
graph = Neo4jGraph()

# Cypher validation tool for relationship directions.
# Default to an empty list so a schema with no relationships doesn't raise
# TypeError when iterating over None.
corrector_schema = [
    Schema(el["start"], el["type"], el["end"])
    for el in graph.structured_schema.get("relationships", [])
]
cypher_validation = CypherQueryCorrector(corrector_schema)

# LLMs: a stronger model for Cypher generation, a cheaper one for phrasing
# the final natural-language answer.
cypher_llm = ChatOpenAI(model_name="gpt-4", temperature=0.0)
qa_llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.0)
| 25 | + |
| 26 | + |
def convert_messages(input: List[Dict[str, Any]]) -> ChatMessageHistory:
    """Build a ChatMessageHistory from stored question/answer rows.

    Each element of *input* is expected to carry a ``result`` dict holding
    ``question`` and ``answer`` strings (the shape returned by the history
    query in ``get_history``).
    """
    history = ChatMessageHistory()
    for record in input:
        result = record["result"]
        history.add_user_message(result["question"])
        history.add_ai_message(result["answer"])
    return history
| 33 | + |
| 34 | + |
def get_history(input: Dict[str, Any]) -> List[Any]:
    """Fetch the most recent Q/A turns for a session as chat messages.

    Reads ``user_id`` and ``session_id`` from *input* (any ``question`` key is
    excluded from the query parameters). Returns a list of alternating
    human/AI messages, oldest first, suitable for a MessagesPlaceholder.
    (The original annotated ChatMessageHistory, but ``.messages`` — a list —
    is what is actually returned.)
    """
    # Lookback conversation window: number of previous Q/A pairs to replay
    window = 3
    # Don't mutate the caller's dict (the original popped "question" in
    # place): the same input flows on through the chain and other branches
    # still need that key.
    params = {key: value for key, value in input.items() if key != "question"}
    data = graph.query(
        """
    MATCH (u:User {id:$user_id})-[:HAS_SESSION]->(s:Session {id:$session_id}),
           (s)-[:LAST_MESSAGE]->(last_message)
    MATCH p=(last_message)<-[:NEXT*0.."""
        + str(window)
        + """]-()
    WITH p, length(p) AS length
    ORDER BY length DESC LIMIT 1
    UNWIND reverse(nodes(p)) AS node
    MATCH (node)-[:HAS_ANSWER]->(answer)
    RETURN {question:node.text, answer:answer.text} AS result
    """,
        params=params,
    )
    return convert_messages(data).messages
| 56 | + |
| 57 | + |
def save_history(input: Dict[str, Any]) -> str:
    """Persist the current Q/A turn to the graph and return the answer.

    Expects ``user_id``, ``session_id``, ``question``, ``query`` (the
    generated Cypher) and ``output`` (the LLM answer) in *input*. The Cypher
    below appends the turn to the session's NEXT-linked list and moves the
    LAST_MESSAGE pointer, creating the session on first use.
    """
    # Exclude the raw Cypher result rows from the write parameters without
    # mutating the caller's dict (the original popped "response" in place).
    params = {key: value for key, value in input.items() if key != "response"}
    # store history to database
    graph.query(
        """MERGE (u:User {id: $user_id})
WITH u
OPTIONAL MATCH (u)-[:HAS_SESSION]->(s:Session{id: $session_id}),
    (s)-[l:LAST_MESSAGE]->(last_message)
FOREACH (_ IN CASE WHEN last_message IS NULL THEN [1] ELSE [] END |
CREATE (u)-[:HAS_SESSION]->(s1:Session {id:$session_id}),
    (s1)-[:LAST_MESSAGE]->(q:Question {text:$question, cypher:$query, date:datetime()}),
    (q)-[:HAS_ANSWER]->(:Answer {text:$output}))
FOREACH (_ IN CASE WHEN last_message IS NOT NULL THEN [1] ELSE [] END |
CREATE (last_message)-[:NEXT]->(q:Question
    {text:$question, cypher:$query, date:datetime()}),
    (q)-[:HAS_ANSWER]->(:Answer {text:$output}),
    (s)-[:LAST_MESSAGE]->(q)
DELETE l) """,
        params=params,
    )

    # Return LLM response to the chain
    return input["output"]
| 81 | + |
| 82 | + |
# Generate Cypher statement based on natural language input.
# NOTE: the template text was garbled in extraction (a stray HTML newline
# entity); reconstructed with a plain blank line between schema and question.
cypher_template = """This is important for my career.
Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question:
{schema}

Question: {question}
Cypher query:"""  # noqa: E501

cypher_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given an input question, convert it to a Cypher query. No pre-amble.",
        ),
        # Prior Q/A turns are injected here so follow-up questions can be
        # resolved against the conversation so far.
        MessagesPlaceholder(variable_name="history"),
        ("human", cypher_template),
    ]
)
| 101 | + |
# Cypher-generation sub-chain: attach the graph schema and conversation
# history to the input, prompt the stronger LLM, and return the raw query
# text.
cypher_response = (
    RunnablePassthrough.assign(
        schema=lambda _: graph.get_schema,
        history=get_history,
    )
    | cypher_prompt
    | cypher_llm.bind(stop=["\nCypherResult:"])
    | StrOutputParser()
)
| 108 | + |
# Generate natural language response based on database results.
# (Fixed the duplicated word "the the" in the prompt text.)
response_template = """Based on the question, Cypher query, and Cypher response, write a natural language response:
Question: {question}
Cypher query: {query}
Cypher Response: {response}"""  # noqa: E501

response_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given an input question and Cypher response, convert it to a "
            "natural language answer. No pre-amble.",
        ),
        ("human", response_template),
    ]
)
| 125 | + |
# Full pipeline: generate Cypher -> validate relationship directions and run
# it against the graph -> phrase the rows as a natural-language answer ->
# store the turn in the conversation graph and return the answer text.
_generate_query = RunnablePassthrough.assign(query=cypher_response)
_run_query = RunnablePassthrough.assign(
    response=lambda x: graph.query(cypher_validation(x["query"])),
)
_phrase_answer = RunnablePassthrough.assign(
    output=response_prompt | qa_llm | StrOutputParser(),
)
chain = _generate_query | _run_query | _phrase_answer | save_history
| 136 | + |
# Add typing for input


class Question(BaseModel):
    """Input schema for the chain: the user's question plus session identity."""

    question: str  # natural-language question to answer from the graph
    user_id: str  # identifies the User node that owns the conversation
    session_id: str  # identifies the Session node within that user's history


# Expose the typed input schema to LangServe / runnable consumers
chain = chain.with_types(input_type=Question)
0 commit comments