3 changes: 3 additions & 0 deletions libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py
@@ -121,6 +121,7 @@ class GraphCypherQAChain(Chain):
graph_schema: str
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
+    example_key: str = "examples"
top_k: int = 10
"""Number of results to return from the query"""
return_intermediate_steps: bool = False
@@ -324,8 +325,10 @@ def _call(
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
+        examples = inputs.get(self.example_key, None)
args = {
"question": question,
"examples": examples,
"schema": self.graph_schema,
}
args.update(inputs)
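Together with the prompt change below, this gives callers a way to pass few-shot examples alongside the question: the chain reads the optional "examples" input and forwards it to the Cypher-generation prompt. A minimal usage sketch, assuming a local Neo4j instance and an OpenAI chat model; the URI, credentials, and example text are placeholders, not part of this PR:

from langchain_neo4j import GraphCypherQAChain, Neo4jGraph
from langchain_openai import ChatOpenAI  # assumed LLM; any chat model works

graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="password")
chain = GraphCypherQAChain.from_llm(
    llm=ChatOpenAI(temperature=0),
    graph=graph,
    allow_dangerous_requests=True,
)

# "examples" is an extra input key: as of this diff, input_keys still only
# requires "query", so the value rides along in the inputs dict and is picked
# up by inputs.get(self.example_key, None).
result = chain.invoke(
    {
        "query": "Which actors played in Top Gun?",
        "examples": (
            "Question: Which movies did Tom Hanks act in?\n"
            "Cypher: MATCH (p:Person {name: 'Tom Hanks'})-[:ACTED_IN]->(m:Movie) "
            "RETURN m.title"
        ),
    }
)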
6 changes: 5 additions & 1 deletion libs/neo4j/langchain_neo4j/chains/graph_qa/prompts.py
@@ -11,11 +11,15 @@
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.

+Examples (optional):
+{examples}
+
The question is:
{question}"""

CYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
input_variables=["schema", "examples", "question"],
template=CYPHER_GENERATION_TEMPLATE,
)

CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
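Since "examples" is now a declared input variable, the template can be rendered standalone to inspect exactly what the model will see. A quick sketch (the import path comes straight from this diff; the schema and question strings are invented). Note that a chain run without examples passes None through, so the section renders literally as "Examples (optional):" over "None", which is exactly what the updated tests below assert:

from langchain_neo4j.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT

# Render with examples supplied.
print(
    CYPHER_GENERATION_PROMPT.format(
        schema="Node properties:\nMovie {title: STRING, released: INTEGER}",
        examples="MATCH (m:Movie) RETURN m.title LIMIT 5",
        question="List five movie titles.",
    )
)

# Omitted examples arrive as None and are interpolated verbatim.
print(CYPHER_GENERATION_PROMPT.format(schema="...", examples=None, question="..."))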
75 changes: 70 additions & 5 deletions libs/neo4j/tests/unit_tests/chains/test_graph_qa.py
@@ -260,7 +260,7 @@ def test_chain_type() -> None:
assert chain._chain_type == "graph_cypher_chain"


-def test_graph_cypher_qa_chain() -> None:
+def test_graph_cypher_qa_chain_without_examples() -> None:
template = """You are a nice chatbot having a conversation with a human.

Schema:
@@ -269,11 +269,15 @@ def test_graph_cypher_qa_chain() -> None:
Previous conversation:
{chat_history}

+    Examples (optional):
+    {examples}
+
New human question: {question}
Response:"""

prompt = PromptTemplate(
input_variables=["schema", "question", "chat_history"], template=template
input_variables=["schema", "question", "examples", "chat_history"],
template=template,
)

memory = ConversationBufferMemory(memory_key="chat_history")
@@ -283,8 +287,8 @@
"Schema:\n Node properties:\n\nRelationship "
"properties:\n\nThe relationships"
":\n\n\n "
"Previous conversation:\n \n\n New human question: "
"Test question\n Response:"
"Previous conversation:\n \n\n Examples (optional):\n"
" None\n\n New human question: Test question\n Response:"
)

prompt2 = (
@@ -293,7 +297,8 @@
"properties:\n\nThe relationships"
":\n\n\n "
"Previous conversation:\n Human: Test question\nAI: foo\n\n "
"New human question: Test new question\n Response:"
"Examples (optional):\n None\n\n New human question: "
"Test new question\n Response:"
)

llm = FakeLLM(queries={prompt1: "answer1", prompt2: "answer2"})
@@ -314,6 +319,66 @@
assert True


def test_graph_cypher_qa_chain_with_examples() -> None:
template = """You are a nice chatbot having a conversation with a human.

Schema:
{schema}

Previous conversation:
{chat_history}

Examples (optional):
{examples}

New human question: {question}
Response:"""

prompt = PromptTemplate(
input_variables=["schema", "question", "examples", "chat_history"],
template=template,
)

memory = ConversationBufferMemory(memory_key="chat_history", input_key="query")
readonlymemory = ReadOnlySharedMemory(memory=memory)
prompt1 = (
"You are a nice chatbot having a conversation with a human.\n\n "
"Schema:\n Node properties:\n\nRelationship "
"properties:\n\nThe relationships"
":\n\n\n "
"Previous conversation:\n \n\n Examples (optional):\n"
" Test examples\n\n New human question: "
"Test question\n Response:"
)

prompt2 = (
"You are a nice chatbot having a conversation with a human.\n\n "
"Schema:\n Node properties:\n\nRelationship "
"properties:\n\nThe relationships"
":\n\n\n "
"Previous conversation:\n Human: Test question\nAI: foo\n\n "
"Examples (optional):\n Test new examples\n\n "
"New human question: Test new question\n Response:"
)

llm = FakeLLM(queries={prompt1: "answer1", prompt2: "answer2"})
chain = GraphCypherQAChain.from_llm(
cypher_llm=llm,
qa_llm=FakeLLM(),
graph=FakeGraphStore(),
verbose=True,
return_intermediate_steps=False,
cypher_llm_kwargs={"prompt": prompt, "memory": readonlymemory},
memory=memory,
allow_dangerous_requests=True,
)
chain.run(query="Test question", examples="Test examples")
chain.run(query="Test new question", examples="Test new examples")
# If we get here without a KeyError, that means memory
# and examples were used properly to create the prompts.
assert True


def test_cypher_generation_failure() -> None:
"""Test the chain doesn't fail if the Cypher query fails to be generated."""
llm = FakeLLM(queries={"query": ""}, sequential_responses=True)
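One thing the diff leaves implicit is what an examples payload looks like in practice. The chain does no parsing of it; the string is interpolated verbatim into the prompt, so the question/Cypher pairing below is only a readable convention, not a required format (contents invented for illustration):

EXAMPLES = """\
Question: Which actors played in Top Gun?
Cypher: MATCH (m:Movie {title: 'Top Gun'})<-[:ACTED_IN]-(p:Person) RETURN p.name

Question: How many movies were released after 2000?
Cypher: MATCH (m:Movie) WHERE m.released > 2000 RETURN count(m)
"""

# Passed per call, e.g. chain.run(query="...", examples=EXAMPLES)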