From f0eba1ac632e6ba8eb556002bf9973d205ede037 Mon Sep 17 00:00:00 2001
From: Lance Martin <122662504+rlancemartin@users.noreply.github.com>
Date: Tue, 31 Oct 2023 17:13:44 -0700
Subject: [PATCH] Add RAG input types (#12684)

Co-authored-by: Erick Friis
---
 templates/chat-bot-feedback/chat_bot_feedback/chain.py | 5 ++---
 .../rag-matching-engine/rag_matching_engine/chain.py   | 9 +++++++++
 .../rag_pinecone_multi_query/chain.py                  | 9 +++++++++
 .../rag-pinecone-rerank/rag_pinecone_rerank/chain.py   | 9 +++++++++
 templates/rag-pinecone/rag_pinecone/chain.py           | 9 +++++++++
 .../rag-semi-structured/rag_semi_structured/chain.py   | 9 +++++++++
 templates/rag-supabase/rag_supabase/chain.py           | 9 +++++++++
 templates/rag-weaviate/rag_weaviate/chain.py           | 9 +++++++++
 8 files changed, 65 insertions(+), 3 deletions(-)

diff --git a/templates/chat-bot-feedback/chat_bot_feedback/chain.py b/templates/chat-bot-feedback/chat_bot_feedback/chain.py
index 42dec54a58c..7fe24d4d8ea 100644
--- a/templates/chat-bot-feedback/chat_bot_feedback/chain.py
+++ b/templates/chat-bot-feedback/chat_bot_feedback/chain.py
@@ -164,9 +164,8 @@ def format_chat_history(chain_input: dict) -> dict:
 # if you update the name of this, you MUST also update ../pyproject.toml
 # with the new `tool.langserve.export_attr`
 chain = (
-    (format_chat_history | _prompt | _model | StrOutputParser()).with_types(
-        input_type=ChainInput
-    )
+    (format_chat_history | _prompt | _model | StrOutputParser())
+    .with_types(input_type=ChainInput)
     # This is to add the evaluators as "listeners"
     # and to customize the name of the chain.
     # Any chain that accepts a compatible input type works here.
diff --git a/templates/rag-matching-engine/rag_matching_engine/chain.py b/templates/rag-matching-engine/rag_matching_engine/chain.py
index e5ad87eae11..b2d215c0929 100644
--- a/templates/rag-matching-engine/rag_matching_engine/chain.py
+++ b/templates/rag-matching-engine/rag_matching_engine/chain.py
@@ -3,6 +3,7 @@ import os
 from langchain.embeddings import VertexAIEmbeddings
 from langchain.llms import VertexAI
 from langchain.prompts import PromptTemplate
+from langchain.pydantic_v1 import BaseModel
 from langchain.schema.output_parser import StrOutputParser
 from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
 from langchain.vectorstores import MatchingEngine
@@ -67,3 +68,11 @@ chain = (
     | model
     | StrOutputParser()
 )
+
+
+# Add typing for input
+class Question(BaseModel):
+    __root__: str
+
+
+chain = chain.with_types(input_type=Question)
diff --git a/templates/rag-pinecone-multi-query/rag_pinecone_multi_query/chain.py b/templates/rag-pinecone-multi-query/rag_pinecone_multi_query/chain.py
index b676de6e9c0..5db9b5026ed 100644
--- a/templates/rag-pinecone-multi-query/rag_pinecone_multi_query/chain.py
+++ b/templates/rag-pinecone-multi-query/rag_pinecone_multi_query/chain.py
@@ -3,6 +3,7 @@ import os
 from langchain.chat_models import ChatOpenAI
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.prompts import ChatPromptTemplate
+from langchain.pydantic_v1 import BaseModel
 from langchain.retrievers.multi_query import MultiQueryRetriever
 from langchain.schema.output_parser import StrOutputParser
 from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
@@ -55,3 +56,11 @@ chain = (
     | model
     | StrOutputParser()
 )
+
+
+# Add typing for input
+class Question(BaseModel):
+    __root__: str
+
+
+chain = chain.with_types(input_type=Question)
diff --git a/templates/rag-pinecone-rerank/rag_pinecone_rerank/chain.py b/templates/rag-pinecone-rerank/rag_pinecone_rerank/chain.py
index 46171f4f165..ba1d37ea2c7 100644
--- a/templates/rag-pinecone-rerank/rag_pinecone_rerank/chain.py
+++ b/templates/rag-pinecone-rerank/rag_pinecone_rerank/chain.py
@@ -3,6 +3,7 @@ import os
 from langchain.chat_models import ChatOpenAI
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.prompts import ChatPromptTemplate
+from langchain.pydantic_v1 import BaseModel
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain.retrievers.document_compressors import CohereRerank
 from langchain.schema.output_parser import StrOutputParser
@@ -62,3 +63,11 @@ chain = (
     | model
     | StrOutputParser()
 )
+
+
+# Add typing for input
+class Question(BaseModel):
+    __root__: str
+
+
+chain = chain.with_types(input_type=Question)
diff --git a/templates/rag-pinecone/rag_pinecone/chain.py b/templates/rag-pinecone/rag_pinecone/chain.py
index 6777010d2a1..7e9dfb79666 100644
--- a/templates/rag-pinecone/rag_pinecone/chain.py
+++ b/templates/rag-pinecone/rag_pinecone/chain.py
@@ -3,6 +3,7 @@ import os
 from langchain.chat_models import ChatOpenAI
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.prompts import ChatPromptTemplate
+from langchain.pydantic_v1 import BaseModel
 from langchain.schema.output_parser import StrOutputParser
 from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
 from langchain.vectorstores import Pinecone
@@ -50,3 +51,11 @@ chain = (
     | model
     | StrOutputParser()
 )
+
+
+# Add typing for input
+class Question(BaseModel):
+    __root__: str
+
+
+chain = chain.with_types(input_type=Question)
diff --git a/templates/rag-semi-structured/rag_semi_structured/chain.py b/templates/rag-semi-structured/rag_semi_structured/chain.py
index e5e923ab55a..d0775ae318a 100644
--- a/templates/rag-semi-structured/rag_semi_structured/chain.py
+++ b/templates/rag-semi-structured/rag_semi_structured/chain.py
@@ -4,6 +4,7 @@ import uuid
 from langchain.chat_models import ChatOpenAI
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.prompts import ChatPromptTemplate
+from langchain.pydantic_v1 import BaseModel
 from langchain.retrievers.multi_vector import MultiVectorRetriever
 from langchain.schema.document import Document
 from langchain.schema.output_parser import StrOutputParser
@@ -109,3 +110,11 @@ chain = (
     | model
     | StrOutputParser()
 )
+
+
+# Add typing for input
+class Question(BaseModel):
+    __root__: str
+
+
+chain = chain.with_types(input_type=Question)
diff --git a/templates/rag-supabase/rag_supabase/chain.py b/templates/rag-supabase/rag_supabase/chain.py
index e116e840c12..1fead91b411 100644
--- a/templates/rag-supabase/rag_supabase/chain.py
+++ b/templates/rag-supabase/rag_supabase/chain.py
@@ -3,6 +3,7 @@ import os
 from langchain.chat_models import ChatOpenAI
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.prompts import ChatPromptTemplate
+from langchain.pydantic_v1 import BaseModel
 from langchain.schema.output_parser import StrOutputParser
 from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
 from langchain.vectorstores.supabase import SupabaseVectorStore
@@ -39,3 +40,11 @@ chain = (
     | model
     | StrOutputParser()
 )
+
+
+# Add typing for input
+class Question(BaseModel):
+    __root__: str
+
+
+chain = chain.with_types(input_type=Question)
diff --git a/templates/rag-weaviate/rag_weaviate/chain.py b/templates/rag-weaviate/rag_weaviate/chain.py
index 3828ec9abae..8f128fbf8aa 100644
--- a/templates/rag-weaviate/rag_weaviate/chain.py
+++ b/templates/rag-weaviate/rag_weaviate/chain.py
@@ -4,6 +4,7 @@ from langchain.chat_models import ChatOpenAI
 from langchain.document_loaders import WebBaseLoader
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.prompts import ChatPromptTemplate
+from langchain.pydantic_v1 import BaseModel
 from langchain.schema.output_parser import StrOutputParser
 from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -50,3 +51,11 @@ chain = (
     | model
     | StrOutputParser()
 )
+
+
+# Add typing for input
+class Question(BaseModel):
+    __root__: str
+
+
+chain = chain.with_types(input_type=Question)
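
A minimal usage sketch of the pattern this patch applies, not part of the diff itself: `Question` is a pydantic v1 root model, so a typed chain still accepts a plain string, while LangServe can read its `input_schema` to render a typed playground form. The `RunnableLambda` stand-in below is hypothetical; the templates above attach the same typing to their real retrieval chains.

# Hypothetical stand-in chain; the real templates pipe a retriever, prompt,
# model, and StrOutputParser together instead.
from langchain.pydantic_v1 import BaseModel
from langchain.schema.runnable import RunnableLambda


class Question(BaseModel):
    __root__: str


chain = RunnableLambda(lambda question: f"You asked: {question}")

# The same call the patch adds to each template's chain.py.
typed_chain = chain.with_types(input_type=Question)

# LangServe reads this schema for the playground; a string root model means
# the endpoint accepts a bare string as input.
print(typed_chain.input_schema.schema())

# Invocation is unchanged.
print(typed_chain.invoke("What is in the indexed documents?"))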