Add input types to cypher templates (#12800)

This commit is contained in:
Tomaz Bratanic 2023-11-02 20:46:02 +01:00 committed by GitHub
parent c4fdf78d03
commit 2a9f40ed28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 5 deletions

View File

@ -5,14 +5,10 @@ from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
try:
from pydantic.v1.main import BaseModel, Field
except ImportError:
from pydantic.main import BaseModel, Field
# Connection to Neo4j
graph = Neo4jGraph()
@ -127,3 +123,12 @@ chain = (
| qa_llm
| StrOutputParser()
)
# Add typing for input
class Question(BaseModel):
question: str
chain = chain.with_types(input_type=Question)

View File

@ -2,6 +2,7 @@ from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
@ -71,3 +72,12 @@ chain = (
| qa_llm
| StrOutputParser()
)
# Add typing for input
class Question(BaseModel):
question: str
chain = chain.with_types(input_type=Question)