mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 11:39:18 +00:00
Update SQL templates (#12464)
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
from llamacpp.chain import chain
|
||||
from sql_llamacpp.chain import chain
|
||||
|
||||
__all__ = ["chain"]
|
@@ -6,6 +6,7 @@ import requests
|
||||
from langchain.llms import LlamaCpp
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema.output_parser import StrOutputParser
|
||||
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
|
||||
from langchain.utilities import SQLDatabase
|
||||
@@ -113,9 +114,12 @@ prompt_response = ChatPromptTemplate.from_messages(
|
||||
("human", template),
|
||||
]
|
||||
)
|
||||
# Supply the input types to the prompt
|
||||
class InputType(BaseModel):
|
||||
question: str
|
||||
|
||||
chain = (
|
||||
RunnablePassthrough.assign(query=sql_response_memory)
|
||||
RunnablePassthrough.assign(query=sql_response_memory).with_types(input_type=InputType)
|
||||
| RunnablePassthrough.assign(
|
||||
schema=get_schema,
|
||||
response=lambda x: db.run(x["query"]),
|
||||
|
Reference in New Issue
Block a user