From 295b9b704b668f39f0417d27757e32201385566c Mon Sep 17 00:00:00 2001 From: Prashanth Rao <35005448+prrao87@users.noreply.github.com> Date: Tue, 16 Apr 2024 21:01:36 -0400 Subject: [PATCH] community[patch]: Improve Kuzu Cypher generation prompt (#20481) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - [x] **PR title**: "community: improve kuzu cypher generation prompt" - [x] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** Improves the Kùzu Cypher generation prompt to be more robust to open source LLM outputs - **Issue:** N/A - **Dependencies:** N/A - **Twitter handle:** @kuzudb - [x] **Add tests and docs**: If you're adding a new integration, please include No new tests (non-breaking change) - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --- .../langchain/chains/graph_qa/kuzu.py | 28 +++++++++++++++++++ .../langchain/chains/graph_qa/prompts.py | 9 +++--- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/chains/graph_qa/kuzu.py b/libs/langchain/langchain/chains/graph_qa/kuzu.py index 7df4cdc8465..61b044a8852 100644 --- a/libs/langchain/langchain/chains/graph_qa/kuzu.py +++ b/libs/langchain/langchain/chains/graph_qa/kuzu.py @@ -1,6 +1,7 @@ """Question answering over a graph.""" from __future__ import annotations +import re from typing import Any, Dict, List, Optional from langchain_community.graphs.kuzu_graph import KuzuGraph @@ -14,6 +15,30 @@ from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_ from langchain.chains.llm import LLMChain +def remove_prefix(text: str, prefix: str) -> str: + if text.startswith(prefix): + return text[len(prefix) :] + return text + + +def extract_cypher(text: str) -> str: + """Extract Cypher code from a text. 
+ + Args: + text: Text to extract Cypher code from. + + Returns: + Cypher code extracted from the text. + """ + # The pattern to find Cypher code enclosed in triple backticks + pattern = r"```(.*?)```" + + # Find all matches in the input text + matches = re.findall(pattern, text, re.DOTALL) + + return matches[0] if matches else text + + class KuzuQAChain(Chain): """Question-answering against a graph by generating Cypher statements for Kùzu. @@ -84,6 +109,9 @@ class KuzuQAChain(Chain): generated_cypher = self.cypher_generation_chain.run( {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks ) + # Extract Cypher code if it is wrapped in triple backticks + # with the language marker "cypher" + generated_cypher = remove_prefix(extract_cypher(generated_cypher), "cypher") _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) _run_manager.on_text( diff --git a/libs/langchain/langchain/chains/graph_qa/prompts.py b/libs/langchain/langchain/chains/graph_qa/prompts.py index d83aef9a622..a4b5db9583a 100644 --- a/libs/langchain/langchain/chains/graph_qa/prompts.py +++ b/libs/langchain/langchain/chains/graph_qa/prompts.py @@ -76,10 +76,11 @@ NGQL_GENERATION_PROMPT = PromptTemplate( KUZU_EXTRA_INSTRUCTIONS = """ Instructions: -Generate statement with Kùzu Cypher dialect (rather than standard): -1. do not use `WHERE EXISTS` clause to check the existence of a property because Kùzu database has a fixed schema. -2. do not omit relationship pattern. Always use `()-[]->()` instead of `()->()`. -3. do not include any notes or comments even if the statement does not produce the expected result. +Generate the Kùzu dialect of Cypher with the following rules in mind: + +1. Do not use a `WHERE EXISTS` clause to check the existence of a property. +2. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`. +3. Do not include any notes or comments even if the statement does not produce the expected result. 
```\n""" KUZU_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(