From 2a9f40ed28b3ecbe7142a202b561c2d976820455 Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Thu, 2 Nov 2023 20:46:02 +0100 Subject: [PATCH] Add input types to cypher templates (#12800) --- .../neo4j-cypher-ft/neo4j_cypher_ft/chain.py | 15 ++++++++++----- templates/neo4j-cypher/neo4j_cypher/chain.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/templates/neo4j-cypher-ft/neo4j_cypher_ft/chain.py b/templates/neo4j-cypher-ft/neo4j_cypher_ft/chain.py index 4868140b916..7d065315bc1 100644 --- a/templates/neo4j-cypher-ft/neo4j_cypher_ft/chain.py +++ b/templates/neo4j-cypher-ft/neo4j_cypher_ft/chain.py @@ -5,14 +5,10 @@ from langchain.chains.openai_functions import create_structured_output_chain from langchain.chat_models import ChatOpenAI from langchain.graphs import Neo4jGraph from langchain.prompts import ChatPromptTemplate +from langchain.pydantic_v1 import BaseModel, Field from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import RunnablePassthrough -try: - from pydantic.v1.main import BaseModel, Field -except ImportError: - from pydantic.main import BaseModel, Field - # Connection to Neo4j graph = Neo4jGraph() @@ -127,3 +123,12 @@ chain = ( | qa_llm | StrOutputParser() ) + +# Add typing for input + + +class Question(BaseModel): + question: str + + +chain = chain.with_types(input_type=Question) diff --git a/templates/neo4j-cypher/neo4j_cypher/chain.py b/templates/neo4j-cypher/neo4j_cypher/chain.py index 730b2d2947d..16ca63f6d58 100644 --- a/templates/neo4j-cypher/neo4j_cypher/chain.py +++ b/templates/neo4j-cypher/neo4j_cypher/chain.py @@ -2,6 +2,7 @@ from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema from langchain.chat_models import ChatOpenAI from langchain.graphs import Neo4jGraph from langchain.prompts import ChatPromptTemplate +from langchain.pydantic_v1 import BaseModel from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import RunnablePassthrough @@ -71,3 +72,12 @@ chain = ( | qa_llm | StrOutputParser() ) + +# Add typing for input + + +class Question(BaseModel): + question: str + + +chain = chain.with_types(input_type=Question)