# Mirror of https://github.com/hwchase17/langchain.git
# (synced 2025-11-04 02:03:32 +00:00; 84 lines, 2.2 KiB, Python)
import os
 | 
						|
 | 
						|
from langchain_community.chat_models import ChatOpenAI
 | 
						|
from langchain_community.document_loaders import PyPDFLoader
 | 
						|
from langchain_community.embeddings import OpenAIEmbeddings
 | 
						|
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
 | 
						|
from langchain_core.output_parsers import StrOutputParser
 | 
						|
from langchain_core.prompts import ChatPromptTemplate
 | 
						|
from langchain_core.pydantic_v1 import BaseModel
 | 
						|
from langchain_core.runnables import (
 | 
						|
    RunnableLambda,
 | 
						|
    RunnableParallel,
 | 
						|
    RunnablePassthrough,
 | 
						|
)
 | 
						|
from langchain_text_splitters import RecursiveCharacterTextSplitter
 | 
						|
from pymongo import MongoClient
 | 
						|
 | 
						|
# Set DB
# Read the connection string exactly once. An unset *or empty* value is
# rejected here with a clear message rather than being handed to MongoClient.
# RuntimeError subclasses Exception, so existing `except Exception` handlers
# keep working.
MONGO_URI = os.environ.get("MONGO_URI")
if not MONGO_URI:
    raise RuntimeError("Missing `MONGO_URI` environment variable.")

DB_NAME = "langchain-test-2"
COLLECTION_NAME = "test"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "default"

# Single client / collection handle for the module.
client = MongoClient(MONGO_URI)
db = client[DB_NAME]
MONGODB_COLLECTION = db[COLLECTION_NAME]

# Read from MongoDB Atlas Vector Search.
# Wrap the collection handle created above instead of calling
# `from_connection_string(MONGO_URI, ...)`, which would open a second client
# (and connection pool) to the same cluster for the same DB_NAME.COLLECTION_NAME.
# `disallowed_special=()` is forwarded to the tokenizer so that text containing
# special-token strings is embedded rather than rejected
# (NOTE(review): confirm against the OpenAIEmbeddings docs).
vectorstore = MongoDBAtlasVectorSearch(
    collection=MONGODB_COLLECTION,
    embedding=OpenAIEmbeddings(disallowed_special=()),
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)
retriever = vectorstore.as_retriever()

# RAG prompt: restrict the model to the retrieved `context` when answering
# the user's `question`.
template = (
    "Answer the question based only on the following context:\n"
    "{context}\n"
    "Question: {question}\n"
)
prompt = ChatPromptTemplate.from_template(template)

# RAG
# The raw question string goes to two branches in parallel: the retriever
# (producing `context`) and a passthrough (producing `question`). The
# resulting mapping fills the prompt, which is sent to the chat model and
# reduced to a plain string.
model = ChatOpenAI()
chain = (
    RunnableParallel(
        context=retriever,
        question=RunnablePassthrough(),
    )
    | prompt
    | model
    | StrOutputParser()
)


# Add typing for input.
# The chain's input is a bare string; pydantic v1 expresses "the model *is* a
# str" through a custom root type (`__root__`). No class docstring is added
# on purpose: it would become the schema's description.
class Question(BaseModel):
    __root__: str


# Publish `Question` as the chain's declared input type.
chain = chain.with_types(input_type=Question)


def _ingest(url: str) -> dict:
    """Load a PDF, split it into chunks, and store the chunks for retrieval.

    Args:
        url: Location of the PDF, passed unchanged to ``PyPDFLoader``.

    Returns:
        An empty dict; the useful work is the side effect of writing the
        embedded chunks to the MongoDB collection.
    """
    data = PyPDFLoader(url).load()

    # Split docs into chunks of at most 500 units of the splitter's default
    # length function, with no overlap between consecutive chunks.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
    docs = text_splitter.split_documents(data)

    # Insert through the module-level store so ingestion uses the same
    # embedding model, index name and collection as the retriever. Previously
    # each call rebuilt an OpenAIEmbeddings client and a store wrapper with
    # duplicated settings.
    vectorstore.add_documents(docs)
    return {}


ingest = RunnableLambda(_ingest)