From 7b836412b51477764809f63e59b0f322880da821 Mon Sep 17 00:00:00 2001 From: Bayu Siddhi Mukti Date: Fri, 1 Aug 2025 16:25:16 +0700 Subject: [PATCH] Add examples support to GraphCypherQAChain --- .../langchain_neo4j/chains/graph_qa/cypher.py | 3 + .../chains/graph_qa/prompts.py | 6 +- .../tests/unit_tests/chains/test_graph_qa.py | 75 +++++++++++++++++-- 3 files changed, 78 insertions(+), 6 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py index 9a5c200..4710b7a 100644 --- a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py +++ b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py @@ -121,6 +121,7 @@ class GraphCypherQAChain(Chain): graph_schema: str input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: + example_key: str = "examples" top_k: int = 10 """Number of results to return from the query""" return_intermediate_steps: bool = False @@ -324,8 +325,10 @@ def _call( _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() callbacks = _run_manager.get_child() question = inputs[self.input_key] + examples = inputs.get(self.example_key, None) args = { "question": question, + "examples": examples, "schema": self.graph_schema, } args.update(inputs) diff --git a/libs/neo4j/langchain_neo4j/chains/graph_qa/prompts.py b/libs/neo4j/langchain_neo4j/chains/graph_qa/prompts.py index 9cac173..fd0a546 100644 --- a/libs/neo4j/langchain_neo4j/chains/graph_qa/prompts.py +++ b/libs/neo4j/langchain_neo4j/chains/graph_qa/prompts.py @@ -11,11 +11,15 @@ Do not respond to any questions that might ask anything else than for you to construct a Cypher statement. Do not include any text except the generated Cypher statement. +Examples (optional): +{examples} + The question is: {question}""" CYPHER_GENERATION_PROMPT = PromptTemplate( - input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE + input_variables=["schema", "examples", "question"], + template=CYPHER_GENERATION_TEMPLATE, ) CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers. diff --git a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py index dc2117f..c55cc12 100644 --- a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py @@ -260,7 +260,7 @@ def test_chain_type() -> None: assert chain._chain_type == "graph_cypher_chain" -def test_graph_cypher_qa_chain() -> None: +def test_graph_cypher_qa_chain_without_examples() -> None: template = """You are a nice chatbot having a conversation with a human. Schema: @@ -269,11 +269,15 @@ def test_graph_cypher_qa_chain() -> None: Previous conversation: {chat_history} + Examples (optional): + {examples} + New human question: {question} Response:""" prompt = PromptTemplate( - input_variables=["schema", "question", "chat_history"], template=template + input_variables=["schema", "question", "examples", "chat_history"], + template=template, ) memory = ConversationBufferMemory(memory_key="chat_history") @@ -283,8 +287,8 @@ def test_graph_cypher_qa_chain() -> None: "Schema:\n Node properties:\n\nRelationship " "properties:\n\nThe relationships" ":\n\n\n " - "Previous conversation:\n \n\n New human question: " - "Test question\n Response:" + "Previous conversation:\n \n\n Examples (optional):\n" + " None\n\n New human question: Test question\n Response:" ) prompt2 = ( @@ -293,7 +297,8 @@ def test_graph_cypher_qa_chain() -> None: "properties:\n\nThe relationships" ":\n\n\n " "Previous conversation:\n Human: Test question\nAI: foo\n\n " - "New human question: Test new question\n Response:" + "Examples (optional):\n None\n\n New human question: " + "Test new question\n Response:" ) llm = FakeLLM(queries={prompt1: "answer1", prompt2: "answer2"}) @@ -314,6 +319,66 @@ def test_graph_cypher_qa_chain() -> None: assert True +def test_graph_cypher_qa_chain_with_examples() -> None: + template = """You are a nice chatbot having a conversation with a human. + + Schema: + {schema} + + Previous conversation: + {chat_history} + + Examples (optional): + {examples} + + New human question: {question} + Response:""" + + prompt = PromptTemplate( + input_variables=["schema", "question", "examples", "chat_history"], + template=template, + ) + + memory = ConversationBufferMemory(memory_key="chat_history", input_key="query") + readonlymemory = ReadOnlySharedMemory(memory=memory) + prompt1 = ( + "You are a nice chatbot having a conversation with a human.\n\n " + "Schema:\n Node properties:\n\nRelationship " + "properties:\n\nThe relationships" + ":\n\n\n " + "Previous conversation:\n \n\n Examples (optional):\n" + " Test examples\n\n New human question: " + "Test question\n Response:" + ) + + prompt2 = ( + "You are a nice chatbot having a conversation with a human.\n\n " + "Schema:\n Node properties:\n\nRelationship " + "properties:\n\nThe relationships" + ":\n\n\n " + "Previous conversation:\n Human: Test question\nAI: foo\n\n " + "Examples (optional):\n Test new examples\n\n " + "New human question: Test new question\n Response:" + ) + + llm = FakeLLM(queries={prompt1: "answer1", prompt2: "answer2"}) + chain = GraphCypherQAChain.from_llm( + cypher_llm=llm, + qa_llm=FakeLLM(), + graph=FakeGraphStore(), + verbose=True, + return_intermediate_steps=False, + cypher_llm_kwargs={"prompt": prompt, "memory": readonlymemory}, + memory=memory, + allow_dangerous_requests=True, + ) + chain.run(query="Test question", examples="Test examples") + chain.run(query="Test new question", examples="Test new examples") + # If we get here without a key error, that means memory + # and examples were used properly to create prompts. + assert True + + def test_cypher_generation_failure() -> None: """Test the chain doesn't fail if the Cypher query fails to be generated.""" llm = FakeLLM(queries={"query": ""}, sequential_responses=True)