Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-04 00:00:34 +00:00)

Comparing 8 commits on branch sr/system-... (vectorsto...)
Commits:

- 94b36f8340
- 501701a396
- c342b45589
- b254d08830
- 242d055b12
- 13b9db6d2d
- 33252fa352
- eb95a3917a
docs/docs/integrations/vectorstores/pgvector_async.ipynb (new file, 363 lines)
@@ -0,0 +1,363 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PGVector with async connections\n",
    "\n",
    ">[PGVector](https://github.com/pgvector/pgvector) is an open-source vector similarity search extension for `Postgres`.\n",
    "\n",
    "It supports:\n",
    "- exact and approximate nearest neighbor search\n",
    "- L2 distance, inner product, and cosine distance\n",
    "\n",
    "This notebook shows how to use the Postgres vector database (`PGVector`) with async connections."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [installation instructions](https://github.com/pgvector/pgvector)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the necessary packages\n",
    "!pip install pgvector\n",
    "!pip install openai\n",
    "!pip install asyncpg\n",
    "!pip install greenlet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We want to use `OpenAIEmbeddings`, so we have to get the OpenAI API key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import getpass\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Loading Environment Variables\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.vectorstores.pgvector_async import PGVectorAsync\n",
    "from langchain.document_loaders import TextLoader\n",
    "from langchain.docstore.document import Document"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
    "documents = loader.load()\n",
    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "docs = text_splitter.split_documents(documents)\n",
    "\n",
    "embeddings = OpenAIEmbeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PGVectorAsync needs the database URL to connect to the database.\n",
    "\n",
    "DATABASE_URL = \"postgresql+asyncpg://postgres:postgres@localhost:5432/postgres\"\n",
    "\n",
    "# Alternatively, you can pass an async engine to PGVectorAsync:\n",
    "# engine = create_async_engine(url=DATABASE_URL, echo=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up your database\n",
    "\n",
    "You only need to run this once, preferably in a migration script."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vectorstore = PGVectorAsync(\n",
    "    embeddings=embeddings,\n",
    "    db_url=DATABASE_URL,\n",
    ")\n",
    "\n",
    "# Alternatively, you can pass an async engine to PGVectorAsync:\n",
    "# vectorstore = PGVectorAsync(\n",
    "#     embeddings=embeddings,\n",
    "#     engine=engine,\n",
    "# )\n",
    "\n",
    "await vectorstore.create_schema()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Similarity Search with Euclidean Distance (Default)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "COLLECTION_NAME = \"state_of_the_union_test\"\n",
    "\n",
    "vectorstore = await PGVectorAsync.afrom_documents(\n",
    "    collection_name=COLLECTION_NAME,\n",
    "    embedding=embeddings,\n",
    "    documents=docs,\n",
    "    db_url=DATABASE_URL,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
    "docs_with_score = await vectorstore.asimilarity_search_with_score(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for doc, score in docs_with_score:\n",
    "    print(\"-\" * 80)\n",
    "    print(\"Score: \", score)\n",
    "    print(doc.page_content)\n",
    "    print(\"-\" * 80)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Maximal Marginal Relevance Search (MMR)\n",
    "\n",
    "Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "docs_with_score = await vectorstore.amax_marginal_relevance_search_with_score(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for doc, score in docs_with_score:\n",
    "    print(\"-\" * 80)\n",
    "    print(\"Score: \", score)\n",
    "    print(doc.page_content)\n",
    "    print(\"-\" * 80)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Working with the vectorstore\n",
    "\n",
    "Above, we created a vectorstore from scratch. However, we often want to work with an existing vectorstore.\n",
    "In order to do that, we can initialize it directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vectorstore = PGVectorAsync(\n",
    "    collection_name=COLLECTION_NAME,\n",
    "    embeddings=embeddings,\n",
    "    db_url=DATABASE_URL,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Add documents\n",
    "\n",
    "We can add documents to the existing vectorstore."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "await vectorstore.aadd_documents(documents=[Document(page_content=\"foo\")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "docs_with_score = await vectorstore.asimilarity_search_with_score(\"foo\", k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "docs_with_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Overriding a vectorstore\n",
    "\n",
    "If you have an existing collection, you can override it by calling `from_documents` with `pre_delete_collection=True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "docs = [Document(page_content=\"foo\"), Document(page_content=\"bar\")]\n",
    "vectorstore = await PGVectorAsync.afrom_documents(\n",
    "    collection_name=COLLECTION_NAME,\n",
    "    embedding=embeddings,\n",
    "    db_url=DATABASE_URL,\n",
    "    documents=docs,\n",
    "    pre_delete_collection=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "docs_with_score = await vectorstore.asimilarity_search_with_score(\"foo\", k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "docs_with_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using a VectorStore as a Retriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = vectorstore.as_retriever()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "await retriever.aget_relevant_documents(query=\"foo\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
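Outside Jupyter (which allows top-level `await`), the same flow needs an explicit event loop. A condensed script-form sketch of the notebook, assuming exactly the `PGVectorAsync` surface exercised above (`create_schema`, `afrom_documents`, `asimilarity_search_with_score`):

import asyncio

from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.pgvector_async import PGVectorAsync

DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"


async def main() -> None:
    embeddings = OpenAIEmbeddings()

    # One-time schema setup (the notebook suggests doing this in a migration).
    store = PGVectorAsync(embeddings=embeddings, db_url=DATABASE_URL)
    await store.create_schema()

    # Ingest a few documents, then query them back with scores.
    store = await PGVectorAsync.afrom_documents(
        collection_name="state_of_the_union_test",
        embedding=embeddings,
        documents=[Document(page_content="foo"), Document(page_content="bar")],
        db_url=DATABASE_URL,
    )
    for doc, score in await store.asimilarity_search_with_score("foo", k=2):
        print(score, doc.page_content)


if __name__ == "__main__":
    asyncio.run(main())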
libs/langchain/langchain/vectorstores/pgvector_async.py (new file, 1243 lines)
File diff suppressed because it is too large.
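The implementation itself is suppressed above. For orientation, a minimal, hypothetical sketch of the SQLAlchemy async-engine pattern that an asyncpg-backed store like this sits on (background under stated assumptions, not the module's actual code; the notebook's commented-out `create_async_engine(url=..., echo=True)` line matches this pattern):

# Hypothetical background sketch, not the suppressed module's code.
# Requires: sqlalchemy, asyncpg, greenlet, and a running Postgres.
import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"


async def main() -> None:
    engine = create_async_engine(DATABASE_URL, echo=True)
    async with engine.begin() as conn:
        # pgvector's "vector" column type needs the extension installed once
        # per database; schema setup like create_schema() depends on it.
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    await engine.dispose()


asyncio.run(main())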
libs/langchain/pyproject.toml
@@ -141,6 +141,8 @@ upstash-redis = {version = "^0.15.0", optional = true}
 google-cloud-documentai = {version = "^2.20.1", optional = true}
 fireworks-ai = {version = "^0.6.0", optional = true, python = ">=3.9,<4.0"}
 javelin-sdk = {version = "^0.1.8", optional = true}
+asyncpg = {version = "^0.28.0", optional = true}
+greenlet = {version = "^2.0.2", optional = true}


 [tool.poetry.group.test.dependencies]
@@ -367,6 +369,8 @@ extended_testing = [
     "arxiv",
     "dashvector",
     "sqlite-vss",
+    "asyncpg",
+    "greenlet",
     "rapidocr-onnxruntime",
     "motor",
     "timescale-vector",
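Both new dependencies are optional and are only wired into the `extended_testing` extra above, so following the notebook means installing them directly. A small sanity check (the `pip install` line is an assumption; no dedicated extra is defined for this feature):

# Verify the async driver stack is importable, e.g. after:
#   pip install asyncpg greenlet   (assumption: installed directly, no extra)
import asyncpg  # async Postgres driver behind the postgresql+asyncpg dialect
import greenlet  # lets SQLAlchemy bridge its sync internals onto asyncio

print(asyncpg.__version__, greenlet.__version__)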
New test file under tests/integration_tests/vectorstores (370 lines)
@@ -0,0 +1,370 @@
"""Test PGVectorAsync functionality."""
import os
from contextlib import asynccontextmanager
from typing import List

import pytest
from sqlalchemy import select

from langchain.docstore.document import Document
from langchain.vectorstores.pgvector_async import PGVectorAsync
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings

DRIVER = os.environ.get("TEST_PGVECTOR_DRIVER", "asyncpg")
HOST = os.environ.get("TEST_PGVECTOR_HOST", "localhost")
PORT = int(os.environ.get("TEST_PGVECTOR_PORT", "5432"))
DATABASE = os.environ.get("TEST_PGVECTOR_DATABASE", "postgres")
USER = os.environ.get("TEST_PGVECTOR_USER", "postgres")
PASSWORD = os.environ.get("TEST_PGVECTOR_PASSWORD", "postgres")

DATABASE_URL = f"postgresql+{DRIVER}://{USER}:{PASSWORD}@{HOST}:{PORT}/{DATABASE}"

ADA_TOKEN_COUNT = 1536


class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
    """Fake embeddings functionality for testing."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return simple embeddings."""
        return [
            [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
        ]

    def embed_query(self, text: str) -> List[float]:
        """Return simple embeddings."""
        return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]


@asynccontextmanager
async def with_db():
    vectorstore = PGVectorAsync(
        collection_name="test_collection",
        embeddings=FakeEmbeddingsWithAdaDimension(),
        db_url=DATABASE_URL,
    )

    await vectorstore.drop_schema()
    await vectorstore.create_schema()
    yield
    await vectorstore.drop_schema()
@pytest.mark.asyncio
async def test_pgvector() -> None:
    """Test end to end construction and search."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection",
            embedding=FakeEmbeddingsWithAdaDimension(),
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )
        output = await docsearch.asimilarity_search("foo", k=1)
        assert output == [Document(page_content="foo")]


@pytest.mark.asyncio
async def test_pgvector_embeddings() -> None:
    """Test end to end construction with embeddings and search."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts)
        text_embedding_pairs = list(zip(texts, text_embeddings))
        docsearch = await PGVectorAsync.afrom_embeddings(
            text_embeddings=text_embedding_pairs,
            collection_name="test_collection",
            embedding=FakeEmbeddingsWithAdaDimension(),
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )
        output = await docsearch.asimilarity_search("foo", k=1)
        assert output == [Document(page_content="foo")]


@pytest.mark.asyncio
async def test_pgvector_with_metadatas() -> None:
    """Test end to end construction and search."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection",
            embedding=FakeEmbeddingsWithAdaDimension(),
            metadatas=metadatas,
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )
        output = await docsearch.asimilarity_search("foo", k=1)
        assert output == [Document(page_content="foo", metadata={"page": "0"})]


@pytest.mark.asyncio
async def test_pgvector_with_metadatas_with_scores() -> None:
    """Test end to end construction and search."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection",
            embedding=FakeEmbeddingsWithAdaDimension(),
            metadatas=metadatas,
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )
        output = await docsearch.asimilarity_search_with_score("foo", k=1)
        assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]


@pytest.mark.asyncio
async def test_pgvector_with_filter_match() -> None:
    """Test end to end construction and search."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection_filter",
            embedding=FakeEmbeddingsWithAdaDimension(),
            metadatas=metadatas,
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )
        output = await docsearch.asimilarity_search_with_score(
            "foo", k=1, filter={"page": "0"}
        )
        assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]


@pytest.mark.asyncio
async def test_pgvector_with_filter_distant_match() -> None:
    """Test end to end construction and search."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection_filter",
            embedding=FakeEmbeddingsWithAdaDimension(),
            metadatas=metadatas,
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )
        output = await docsearch.asimilarity_search_with_score(
            "foo", k=1, filter={"page": "2"}
        )
        assert output == [
            (
                Document(page_content="baz", metadata={"page": "2"}),
                0.0013003906671379406,
            )
        ]


@pytest.mark.asyncio
async def test_pgvector_with_filter_no_match() -> None:
    """Test end to end construction and search."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection_filter",
            embedding=FakeEmbeddingsWithAdaDimension(),
            metadatas=metadatas,
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )
        output = await docsearch.asimilarity_search_with_score(
            "foo", k=1, filter={"page": "5"}
        )
        assert output == []
@pytest.mark.asyncio
async def test_pgvector_collection_with_metadata() -> None:
    """Test end to end collection construction."""
    async with with_db():
        pgvector = PGVectorAsync(
            collection_name="test_collection",
            collection_metadata={"foo": "bar"},
            embeddings=FakeEmbeddingsWithAdaDimension(),
            db_url=DATABASE_URL,
        )
        await pgvector.delete_collection()  # Delete collection if it exists
        await pgvector.create_collection()
        collection = await pgvector.get_collection()
        assert (
            collection is not None
        ), "Expected a CollectionStore object but received None"
        assert collection.name == "test_collection"
        assert collection.cmetadata == {"foo": "bar"}
@pytest.mark.asyncio
async def test_pgvector_with_filter_in_set() -> None:
    """Test end to end construction and search."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection_filter",
            embedding=FakeEmbeddingsWithAdaDimension(),
            metadatas=metadatas,
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )
        output = await docsearch.asimilarity_search_with_score(
            "foo", k=2, filter={"page": {"IN": ["0", "2"]}}
        )
        assert output == [
            (Document(page_content="foo", metadata={"page": "0"}), 0.0),
            (
                Document(page_content="baz", metadata={"page": "2"}),
                0.0013003906671379406,
            ),
        ]


@pytest.mark.asyncio
async def test_pgvector_delete_docs() -> None:
    """Add and delete documents."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection_filter",
            embedding=FakeEmbeddingsWithAdaDimension(),
            metadatas=metadatas,
            ids=["1", "2", "3"],
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )
        await docsearch.adelete(["1", "2"])
        async with docsearch._make_session() as session:
            query = select(docsearch.EmbeddingStore)
            results = await session.execute(query)
            records = list(results.scalars().all())

            assert sorted(record.custom_id for record in records) == ["3"]

        await docsearch.adelete(["2", "3"])  # Should not raise on missing ids
        async with docsearch._make_session() as session:
            query = select(docsearch.EmbeddingStore)
            results = await session.execute(query)
            records = list(results.scalars().all())

            assert sorted(record.custom_id for record in records) == []  # type: ignore


@pytest.mark.asyncio
async def test_pgvector_relevance_score() -> None:
    """Test to make sure the relevance score is scaled to 0-1."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection",
            embedding=FakeEmbeddingsWithAdaDimension(),
            metadatas=metadatas,
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )

        output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3)
        assert output == [
            (Document(page_content="foo", metadata={"page": "0"}), 1.0),
            (Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065),
            (Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621),
        ]


@pytest.mark.asyncio
async def test_pgvector_retriever_search_threshold() -> None:
    """Test using retriever for searching with threshold."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection",
            embedding=FakeEmbeddingsWithAdaDimension(),
            metadatas=metadatas,
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )

        retriever = docsearch.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"k": 3, "score_threshold": 0.999},
        )
        output = await retriever.aget_relevant_documents("summer")
        assert output == [
            Document(page_content="foo", metadata={"page": "0"}),
            Document(page_content="bar", metadata={"page": "1"}),
        ]


@pytest.mark.asyncio
async def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None:
    """Test searching with threshold and custom normalization function."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection",
            embedding=FakeEmbeddingsWithAdaDimension(),
            metadatas=metadatas,
            db_url=DATABASE_URL,
            pre_delete_collection=True,
            relevance_score_fn=lambda d: d * 0,
        )

        retriever = docsearch.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"k": 3, "score_threshold": 0.5},
        )
        output = await retriever.aget_relevant_documents("foo")
        assert output == []


@pytest.mark.asyncio
async def test_pgvector_max_marginal_relevance_search() -> None:
    """Test max marginal relevance search."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection",
            embedding=FakeEmbeddingsWithAdaDimension(),
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )
        output = await docsearch.amax_marginal_relevance_search("foo", k=1, fetch_k=3)
        assert output == [Document(page_content="foo")]


@pytest.mark.asyncio
async def test_pgvector_max_marginal_relevance_search_with_score() -> None:
    """Test max marginal relevance search with relevance scores."""
    async with with_db():
        texts = ["foo", "bar", "baz"]
        docsearch = await PGVectorAsync.afrom_texts(
            texts=texts,
            collection_name="test_collection",
            embedding=FakeEmbeddingsWithAdaDimension(),
            db_url=DATABASE_URL,
            pre_delete_collection=True,
        )
        output = await docsearch.amax_marginal_relevance_search_with_score(
            "foo", k=1, fetch_k=3
        )
        assert output == [(Document(page_content="foo"), 0.0)]
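The tests above read their connection settings from the `TEST_PGVECTOR_*` environment variables and default to a local Postgres. A sketch of pointing them at a differently-mapped instance (the Docker command in the comment is an assumption, one convenient way to get a pgvector-enabled Postgres):

import os
import subprocess
import sys

# Point the tests at a non-default Postgres, e.g. one started with:
#   docker run -d -p 5433:5432 -e POSTGRES_PASSWORD=postgres ankane/pgvector
# (image choice is an assumption; any pgvector-enabled Postgres works).
env = {
    **os.environ,
    "TEST_PGVECTOR_DRIVER": "asyncpg",
    "TEST_PGVECTOR_HOST": "localhost",
    "TEST_PGVECTOR_PORT": "5433",
    "TEST_PGVECTOR_DATABASE": "postgres",
    "TEST_PGVECTOR_USER": "postgres",
    "TEST_PGVECTOR_PASSWORD": "postgres",
}
subprocess.run(
    [sys.executable, "-m", "pytest", "-k", "pgvector", "-x"],
    env=env,
    check=True,
)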