mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-03 18:24:10 +00:00
templates: redis multi-modal multi-vector rag (#18946)
--------- Co-authored-by: Tyler Hutcherson <tyler.hutcherson@redis.com>
This commit is contained in:
parent
915c1f8673
commit
239f0a615e
1
templates/rag-redis-multi-modal-multi-vector/.gitignore
vendored
Normal file
1
templates/rag-redis-multi-modal-multi-vector/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
__pycache__
|
21
templates/rag-redis-multi-modal-multi-vector/LICENSE
Normal file
21
templates/rag-redis-multi-modal-multi-vector/LICENSE
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 LangChain, Inc.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
Binary file not shown.
After Width: | Height: | Size: 1.9 MiB |
119
templates/rag-redis-multi-modal-multi-vector/README.md
Normal file
119
templates/rag-redis-multi-modal-multi-vector/README.md
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
|
||||||
|
# rag-redis-multi-modal-multi-vector
|
||||||
|
|
||||||
|
Multi-modal LLMs enable visual assistants that can perform question-answering about images.
|
||||||
|
|
||||||
|
This template create a visual assistant for slide decks, which often contain visuals such as graphs or figures.
|
||||||
|
|
||||||
|
It uses GPT-4V to create image summaries for each slide, embeds the summaries, and stores them in Redis.
|
||||||
|
|
||||||
|
Given a question, relevant slides are retrieved and passed to GPT-4V for answer synthesis.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## Input
|
||||||
|
|
||||||
|
Supply a slide deck as PDF in the `/docs` directory.
|
||||||
|
|
||||||
|
By default, this template has a slide deck about recent earnings from NVIDIA.
|
||||||
|
|
||||||
|
Example questions to ask can be:
|
||||||
|
```
|
||||||
|
1/ how much can H100 TensorRT improve LLama2 inference performance?
|
||||||
|
2/ what is the % change in GPU accelerated applications from 2020 to 2023?
|
||||||
|
```
|
||||||
|
|
||||||
|
To create an index of the slide deck, run:
|
||||||
|
```
|
||||||
|
poetry install
|
||||||
|
poetry shell
|
||||||
|
python ingest.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Storage
|
||||||
|
|
||||||
|
Here is the process the template will use to create an index of the slides (see [blog](https://blog.langchain.dev/multi-modal-rag-template/)):
|
||||||
|
|
||||||
|
* Extract the slides as a collection of images
|
||||||
|
* Use GPT-4V to summarize each image
|
||||||
|
* Embed the image summaries using text embeddings with a link to the original images
|
||||||
|
* Retrieve relevant image based on similarity between the image summary and the user input question
|
||||||
|
* Pass those images to GPT-4V for answer synthesis
|
||||||
|
|
||||||
|
### Redis
|
||||||
|
This template uses [Redis](https://redis.com) to power the [MultiVectorRetriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector) including:
|
||||||
|
- Redis as the [VectorStore](https://python.langchain.com/docs/integrations/vectorstores/redis) (to store + index image summary embeddings)
|
||||||
|
- Redis as the [ByteStore](https://python.langchain.com/docs/integrations/stores/redis) (to store images)
|
||||||
|
|
||||||
|
Make sure to deploy a Redis instance either in the [cloud](https://redis.com/try-free) (free) or locally with [docker](https://redis.io/docs/install/install-stack/docker/).
|
||||||
|
|
||||||
|
This will give you an accessible Redis endpoint that you can use as a URL. If deploying locally, simply use `redis://localhost:6379`.
|
||||||
|
|
||||||
|
|
||||||
|
## LLM
|
||||||
|
|
||||||
|
The app will retrieve images based on similarity between the text input and the image summary (text), and pass the images to GPT-4V for answer synthesis.
|
||||||
|
|
||||||
|
## Environment Setup
|
||||||
|
|
||||||
|
Set the `OPENAI_API_KEY` environment variable to access the OpenAI GPT-4V.
|
||||||
|
|
||||||
|
Set `REDIS_URL` environment variable to access your Redis database.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use this package, you should first have the LangChain CLI installed:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
pip install -U langchain-cli
|
||||||
|
```
|
||||||
|
|
||||||
|
To create a new LangChain project and install this as the only package, you can do:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
langchain app new my-app --package rag-redis-multi-modal-multi-vector
|
||||||
|
```
|
||||||
|
|
||||||
|
If you want to add this to an existing project, you can just run:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
langchain app add rag-redis-multi-modal-multi-vector
|
||||||
|
```
|
||||||
|
|
||||||
|
And add the following code to your `server.py` file:
|
||||||
|
```python
|
||||||
|
from rag_redis_multi_modal_multi_vector import chain as rag_redis_multi_modal_chain_mv
|
||||||
|
|
||||||
|
add_routes(app, rag_redis_multi_modal_chain_mv, path="/rag-redis-multi-modal-multi-vector")
|
||||||
|
```
|
||||||
|
|
||||||
|
(Optional) Let's now configure LangSmith.
|
||||||
|
LangSmith will help us trace, monitor and debug LangChain applications.
|
||||||
|
LangSmith is currently in private beta, you can sign up [here](https://smith.langchain.com/).
|
||||||
|
If you don't have access, you can skip this section
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export LANGCHAIN_TRACING_V2=true
|
||||||
|
export LANGCHAIN_API_KEY=<your-api-key>
|
||||||
|
export LANGCHAIN_PROJECT=<your-project> # if not specified, defaults to "default"
|
||||||
|
```
|
||||||
|
|
||||||
|
If you are inside this directory, then you can spin up a LangServe instance directly by:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
langchain serve
|
||||||
|
```
|
||||||
|
|
||||||
|
This will start the FastAPI app with a server is running locally at
|
||||||
|
[http://localhost:8000](http://localhost:8000)
|
||||||
|
|
||||||
|
We can see all templates at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs)
|
||||||
|
We can access the playground at [http://127.0.0.1:8000/rag-redis-multi-modal-multi-vector/playground](http://127.0.0.1:8000/rag-redis-multi-modal-multi-vector/playground)
|
||||||
|
|
||||||
|
We can access the template from code with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langserve.client import RemoteRunnable
|
||||||
|
|
||||||
|
runnable = RemoteRunnable("http://localhost:8000/rag-redis-multi-modal-multi-vector")
|
||||||
|
```
|
Binary file not shown.
170
templates/rag-redis-multi-modal-multi-vector/ingest.py
Normal file
170
templates/rag-redis-multi-modal-multi-vector/ingest.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import uuid
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pypdfium2 as pdfium
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langchain_openai.chat_models import ChatOpenAI
|
||||||
|
from PIL import Image
|
||||||
|
from rag_redis_multi_modal_multi_vector.utils import ID_KEY, make_mv_retriever
|
||||||
|
|
||||||
|
|
||||||
|
def image_summarize(img_base64, prompt):
|
||||||
|
"""
|
||||||
|
Make image summary
|
||||||
|
|
||||||
|
:param img_base64: Base64 encoded string for image
|
||||||
|
:param prompt: Text prompt for summarizatiomn
|
||||||
|
:return: Image summarization prompt
|
||||||
|
|
||||||
|
"""
|
||||||
|
chat = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=1024)
|
||||||
|
|
||||||
|
msg = chat.invoke(
|
||||||
|
[
|
||||||
|
HumanMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": prompt},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return msg.content
|
||||||
|
|
||||||
|
|
||||||
|
def generate_img_summaries(img_base64_list):
|
||||||
|
"""
|
||||||
|
Generate summaries for images
|
||||||
|
|
||||||
|
:param img_base64_list: Base64 encoded images
|
||||||
|
:return: List of image summaries and processed images
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Store image summaries
|
||||||
|
image_summaries = []
|
||||||
|
processed_images = []
|
||||||
|
|
||||||
|
# Prompt
|
||||||
|
prompt = """You are an assistant tasked with summarizing images for retrieval. \
|
||||||
|
These summaries will be embedded and used to retrieve the raw image. \
|
||||||
|
Give a concise summary of the image that is well optimized for retrieval."""
|
||||||
|
|
||||||
|
# Apply summarization to images
|
||||||
|
for i, base64_image in enumerate(img_base64_list):
|
||||||
|
try:
|
||||||
|
image_summaries.append(image_summarize(base64_image, prompt))
|
||||||
|
processed_images.append(base64_image)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error with image {i+1}: {e}") # noqa: T201
|
||||||
|
|
||||||
|
return image_summaries, processed_images
|
||||||
|
|
||||||
|
|
||||||
|
def get_images_from_pdf(pdf_path):
|
||||||
|
"""
|
||||||
|
Extract images from each page of a PDF document and save as JPEG files.
|
||||||
|
|
||||||
|
:param pdf_path: A string representing the path to the PDF file.
|
||||||
|
"""
|
||||||
|
pdf = pdfium.PdfDocument(pdf_path)
|
||||||
|
n_pages = len(pdf)
|
||||||
|
pil_images = []
|
||||||
|
for page_number in range(n_pages):
|
||||||
|
page = pdf.get_page(page_number)
|
||||||
|
bitmap = page.render(scale=1, rotation=0, crop=(0, 0, 0, 0))
|
||||||
|
pil_image = bitmap.to_pil()
|
||||||
|
pil_images.append(pil_image)
|
||||||
|
return pil_images
|
||||||
|
|
||||||
|
|
||||||
|
def resize_base64_image(base64_string, size=(128, 128)):
|
||||||
|
"""
|
||||||
|
Resize an image encoded as a Base64 string
|
||||||
|
|
||||||
|
:param base64_string: Base64 string
|
||||||
|
:param size: Image size
|
||||||
|
:return: Re-sized Base64 string
|
||||||
|
"""
|
||||||
|
# Decode the Base64 string
|
||||||
|
img_data = base64.b64decode(base64_string)
|
||||||
|
img = Image.open(io.BytesIO(img_data))
|
||||||
|
|
||||||
|
# Resize the image
|
||||||
|
resized_img = img.resize(size, Image.LANCZOS)
|
||||||
|
|
||||||
|
# Save the resized image to a bytes buffer
|
||||||
|
buffered = io.BytesIO()
|
||||||
|
resized_img.save(buffered, format=img.format)
|
||||||
|
|
||||||
|
# Encode the resized image to Base64
|
||||||
|
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_base64(pil_image):
|
||||||
|
"""
|
||||||
|
Convert PIL images to Base64 encoded strings
|
||||||
|
|
||||||
|
:param pil_image: PIL image
|
||||||
|
:return: Re-sized Base64 string
|
||||||
|
"""
|
||||||
|
|
||||||
|
buffered = BytesIO()
|
||||||
|
pil_image.save(buffered, format="JPEG") # You can change the format if needed
|
||||||
|
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
|
img_str = resize_base64_image(img_str, size=(960, 540))
|
||||||
|
return img_str
|
||||||
|
|
||||||
|
|
||||||
|
def load_images(image_summaries, images):
|
||||||
|
"""
|
||||||
|
Index image summaries in the db.
|
||||||
|
|
||||||
|
:param image_summaries: Image summaries
|
||||||
|
:param images: Base64 encoded images
|
||||||
|
|
||||||
|
:return: Retriever
|
||||||
|
"""
|
||||||
|
|
||||||
|
retriever = make_mv_retriever()
|
||||||
|
|
||||||
|
# Helper function to add documents to the vectorstore and docstore
|
||||||
|
def add_documents(retriever, doc_summaries, doc_contents):
|
||||||
|
doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
|
||||||
|
summary_docs = [
|
||||||
|
Document(page_content=s, metadata={ID_KEY: doc_ids[i]})
|
||||||
|
for i, s in enumerate(doc_summaries)
|
||||||
|
]
|
||||||
|
retriever.vectorstore.add_documents(summary_docs)
|
||||||
|
retriever.docstore.mset(list(zip(doc_ids, doc_contents)))
|
||||||
|
|
||||||
|
add_documents(retriever, image_summaries, images)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
doc_path = Path(__file__).parent / "docs/nvda-f3q24-investor-presentation-final.pdf"
|
||||||
|
rel_doc_path = doc_path.relative_to(Path.cwd())
|
||||||
|
|
||||||
|
print("Extract slides as images") # noqa: T201
|
||||||
|
pil_images = get_images_from_pdf(rel_doc_path)
|
||||||
|
|
||||||
|
# Convert to b64
|
||||||
|
images_base_64 = [convert_to_base64(i) for i in pil_images]
|
||||||
|
|
||||||
|
# Generate image summaries
|
||||||
|
print("Generate image summaries") # noqa: T201
|
||||||
|
image_summaries, images_base_64_processed = generate_img_summaries(images_base_64)
|
||||||
|
|
||||||
|
# Create documents
|
||||||
|
images_base_64_processed_documents = [
|
||||||
|
Document(page_content=i) for i in images_base_64_processed
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create retriever and load images
|
||||||
|
load_images(image_summaries, images_base_64_processed_documents)
|
2037
templates/rag-redis-multi-modal-multi-vector/poetry.lock
generated
Normal file
2037
templates/rag-redis-multi-modal-multi-vector/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
37
templates/rag-redis-multi-modal-multi-vector/pyproject.toml
Normal file
37
templates/rag-redis-multi-modal-multi-vector/pyproject.toml
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "rag-redis-multi-modal-multi-vector"
|
||||||
|
version = "0.0.1"
|
||||||
|
description = "Multi-modal RAG using Redis as the vectorstore and docstore"
|
||||||
|
authors = [
|
||||||
|
"Tyler Hutcherson <tyler.hutcherson@redis.com>"
|
||||||
|
]
|
||||||
|
readme = "README.md"
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = ">=3.8.1,<4.0"
|
||||||
|
langchain-core = ">=0.1.5"
|
||||||
|
langchain-openai = ">=0.0.1"
|
||||||
|
redis = "^5.0.1"
|
||||||
|
openai = "<=2.0.0"
|
||||||
|
pypdfium2 = "^4.27.0"
|
||||||
|
pillow = "^10.2.0"
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
langchain-cli = ">=0.0.21"
|
||||||
|
fastapi = "^0.104.0"
|
||||||
|
sse-starlette = "^1.6.5"
|
||||||
|
|
||||||
|
[tool.langserve]
|
||||||
|
export_module = "rag_redis_multi_modal_multi_vector"
|
||||||
|
export_attr = "chain"
|
||||||
|
|
||||||
|
[tool.templates-hub]
|
||||||
|
use-case = "rag"
|
||||||
|
author = "Redis"
|
||||||
|
integrations = ["OpenAI", "Redis"]
|
||||||
|
tags = ["vectordbs"]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
@ -0,0 +1,3 @@
|
|||||||
|
from rag_redis_multi_modal_multi_vector.chain import chain
|
||||||
|
|
||||||
|
__all__ = ["chain"]
|
@ -0,0 +1,109 @@
|
|||||||
|
import base64
|
||||||
|
import io
|
||||||
|
|
||||||
|
from langchain.pydantic_v1 import BaseModel
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
|
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from rag_redis_multi_modal_multi_vector.utils import make_mv_retriever
|
||||||
|
|
||||||
|
|
||||||
|
def resize_base64_image(base64_string, size=(128, 128)):
|
||||||
|
"""
|
||||||
|
Resize an image encoded as a Base64 string.
|
||||||
|
|
||||||
|
:param base64_string: A Base64 encoded string of the image to be resized.
|
||||||
|
:param size: A tuple representing the new size (width, height) for the image.
|
||||||
|
:return: A Base64 encoded string of the resized image.
|
||||||
|
"""
|
||||||
|
img_data = base64.b64decode(base64_string)
|
||||||
|
img = Image.open(io.BytesIO(img_data))
|
||||||
|
resized_img = img.resize(size, Image.LANCZOS)
|
||||||
|
buffered = io.BytesIO()
|
||||||
|
resized_img.save(buffered, format=img.format)
|
||||||
|
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def get_resized_images(docs):
|
||||||
|
"""
|
||||||
|
Resize images from base64-encoded strings.
|
||||||
|
|
||||||
|
:param docs: A list of base64-encoded image to be resized.
|
||||||
|
:return: Dict containing a list of resized base64-encoded strings.
|
||||||
|
"""
|
||||||
|
b64_images = []
|
||||||
|
for doc in docs:
|
||||||
|
if isinstance(doc, Document):
|
||||||
|
doc = doc.page_content
|
||||||
|
resized_image = resize_base64_image(doc, size=(1280, 720))
|
||||||
|
b64_images.append(resized_image)
|
||||||
|
return {"images": b64_images}
|
||||||
|
|
||||||
|
|
||||||
|
def img_prompt_func(data_dict, num_images=2):
|
||||||
|
"""
|
||||||
|
GPT-4V prompt for image analysis.
|
||||||
|
|
||||||
|
:param data_dict: A dict with images and a user-provided question.
|
||||||
|
:param num_images: Number of images to include in the prompt.
|
||||||
|
:return: A list containing message objects for each image and the text prompt.
|
||||||
|
"""
|
||||||
|
messages = []
|
||||||
|
if data_dict["context"]["images"]:
|
||||||
|
for image in data_dict["context"]["images"][:num_images]:
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": f"data:image/jpeg;base64,{image}"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
text_message = {
|
||||||
|
"type": "text",
|
||||||
|
"text": (
|
||||||
|
"You are an analyst tasked with answering questions about visual content.\n"
|
||||||
|
"You will be give a set of image(s) from a slide deck / presentation.\n"
|
||||||
|
"Use this information to answer the user question. \n"
|
||||||
|
f"User-provided question: {data_dict['question']}\n\n"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
messages.append(text_message)
|
||||||
|
return [HumanMessage(content=messages)]
|
||||||
|
|
||||||
|
|
||||||
|
def multi_modal_rag_chain():
|
||||||
|
"""
|
||||||
|
Multi-modal RAG chain,
|
||||||
|
|
||||||
|
:return: A chain of functions representing the multi-modal RAG process.
|
||||||
|
"""
|
||||||
|
# Initialize the multi-modal Large Language Model with specific parameters
|
||||||
|
model = ChatOpenAI(temperature=0, model="gpt-4-vision-preview", max_tokens=1024)
|
||||||
|
# Initialize the retriever
|
||||||
|
retriever = make_mv_retriever()
|
||||||
|
# Define the RAG pipeline
|
||||||
|
return (
|
||||||
|
{
|
||||||
|
"context": retriever | RunnableLambda(get_resized_images),
|
||||||
|
"question": RunnablePassthrough(),
|
||||||
|
}
|
||||||
|
| RunnableLambda(img_prompt_func)
|
||||||
|
| model
|
||||||
|
| StrOutputParser()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Create RAG chain
|
||||||
|
chain = multi_modal_rag_chain()
|
||||||
|
|
||||||
|
|
||||||
|
# Add typing for input
|
||||||
|
class Question(BaseModel):
|
||||||
|
__root__: str
|
||||||
|
|
||||||
|
|
||||||
|
chain = chain.with_types(input_type=Question)
|
@ -0,0 +1,10 @@
|
|||||||
|
text:
|
||||||
|
- name: content
|
||||||
|
tag:
|
||||||
|
- name: doc_id
|
||||||
|
vector:
|
||||||
|
- name: content_vector
|
||||||
|
algorithm: FLAT
|
||||||
|
datatype: FLOAT32
|
||||||
|
dims: 1536
|
||||||
|
distance_metric: COSINE
|
@ -0,0 +1,88 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
||||||
|
from langchain_community.storage import RedisStore
|
||||||
|
from langchain_community.vectorstores import Redis as RedisVectorDB
|
||||||
|
from langchain_openai.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
|
ID_KEY = "doc_id"
|
||||||
|
|
||||||
|
|
||||||
|
def get_boolean_env_var(var_name, default_value=False):
|
||||||
|
"""Retrieve the boolean value of an environment variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
var_name (str): The name of the environment variable to retrieve.
|
||||||
|
default_value (bool): The default value to return if the variable
|
||||||
|
is not found.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: The value of the environment variable, interpreted as a boolean.
|
||||||
|
"""
|
||||||
|
true_values = {"true", "1", "t", "y", "yes"}
|
||||||
|
false_values = {"false", "0", "f", "n", "no"}
|
||||||
|
|
||||||
|
# Retrieve the environment variable's value
|
||||||
|
value = os.getenv(var_name, "").lower()
|
||||||
|
|
||||||
|
# Decide the boolean value based on the content of the string
|
||||||
|
if value in true_values:
|
||||||
|
return True
|
||||||
|
elif value in false_values:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return default_value
|
||||||
|
|
||||||
|
|
||||||
|
# Check for openai API key
|
||||||
|
if "OPENAI_API_KEY" not in os.environ:
|
||||||
|
raise Exception("Must provide an OPENAI_API_KEY as an env var.")
|
||||||
|
|
||||||
|
|
||||||
|
def format_redis_conn_from_env() -> str:
|
||||||
|
redis_url = os.getenv("REDIS_URL", None)
|
||||||
|
if redis_url:
|
||||||
|
return redis_url
|
||||||
|
else:
|
||||||
|
using_ssl = get_boolean_env_var("REDIS_SSL", False)
|
||||||
|
start = "rediss://" if using_ssl else "redis://"
|
||||||
|
|
||||||
|
# if using RBAC
|
||||||
|
password = os.getenv("REDIS_PASSWORD", None)
|
||||||
|
username = os.getenv("REDIS_USERNAME", "default")
|
||||||
|
if password is not None:
|
||||||
|
start += f"{username}:{password}@"
|
||||||
|
|
||||||
|
host = os.getenv("REDIS_HOST", "localhost")
|
||||||
|
port = int(os.getenv("REDIS_PORT", 6379))
|
||||||
|
|
||||||
|
return start + f"{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
|
REDIS_URL = format_redis_conn_from_env()
|
||||||
|
|
||||||
|
current_file_path = os.path.abspath(__file__)
|
||||||
|
parent_dir = os.path.dirname(current_file_path)
|
||||||
|
schema_path = os.path.join(parent_dir, "schema.yml")
|
||||||
|
INDEX_SCHEMA = schema_path
|
||||||
|
|
||||||
|
|
||||||
|
def make_mv_retriever():
|
||||||
|
"""Create the multi-vector retriever"""
|
||||||
|
# Load Redis
|
||||||
|
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
|
||||||
|
vectorstore = RedisVectorDB(
|
||||||
|
redis_url=REDIS_URL,
|
||||||
|
index_name="image_summaries",
|
||||||
|
key_prefix="summary",
|
||||||
|
index_schema=INDEX_SCHEMA,
|
||||||
|
embedding=OpenAIEmbeddings(),
|
||||||
|
)
|
||||||
|
store = RedisStore(redis_url=REDIS_URL, namespace="image")
|
||||||
|
|
||||||
|
# Create the multi-vector retriever
|
||||||
|
return MultiVectorRetriever(
|
||||||
|
vectorstore=vectorstore,
|
||||||
|
byte_store=store,
|
||||||
|
id_key=ID_KEY,
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user