community[minor]: Update pgvecto_rs to use its high level sdk (#15574)
- **Description:** Update pgvecto_rs to use its high-level SDK.
- **Issue:** fixes #15173
parent ce21392a21
commit ddf4e7c633
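In short, the vector store drops its hand-rolled SQLAlchemy ORM layer and delegates storage, insertion, and search to the SDK's `PGVectoRs` client. A minimal sketch of the SDK surface the new code relies on, assuming a running pgvecto.rs server; the `db_url` below is a placeholder, not part of this commit:

    # Sketch only: connection URL and values are illustrative.
    from pgvecto_rs.sdk import PGVectoRs, Record

    store = PGVectoRs(
        db_url="postgresql+psycopg://postgres:postgres@localhost:5432/postgres",
        collection_name="demo",
        dimension=3,    # must match the embedding width
        recreate=True,  # drop and recreate the table if it exists
    )
    store.insert([Record.from_text("hello world", [1.0, 0.0, 0.0], {"source": "demo"})])
    # "<->" is the squared-euclidean operator that the new code maps "sqrt_euclid" to.
    for record, score in store.search([1.0, 0.0, 0.0], "<->", 4):
        print(record.text, record.meta, score)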
pgvecto_rs.ipynb (PGVecto.rs integration notebook):

@@ -6,7 +6,7 @@
    "source": [
     "# PGVecto.rs\n",
     "\n",
-    "This notebook shows how to use functionality related to the Postgres vector database ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs)). You need to install SQLAlchemy >= 2 manually."
+    "This notebook shows how to use functionality related to the Postgres vector database ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs))."
    ]
   },
   {
@@ -15,10 +15,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "## Loading Environment Variables\n",
-    "from dotenv import load_dotenv\n",
-    "\n",
-    "load_dotenv()"
+    "%pip install \"pgvecto_rs[sdk]\""
    ]
   },
   {
@@ -32,8 +29,8 @@
    "from langchain.docstore.document import Document\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain_community.document_loaders import TextLoader\n",
-    "from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs\n",
-    "from langchain_openai import OpenAIEmbeddings"
+    "from langchain_community.embeddings.fake import FakeEmbeddings\n",
+    "from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs"
    ]
   },
   {
@@ -42,12 +39,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
+    "loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
    "documents = loader.load()\n",
    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "docs = text_splitter.split_documents(documents)\n",
    "\n",
-    "embeddings = OpenAIEmbeddings()"
+    "embeddings = FakeEmbeddings(size=3)"
    ]
   },
   {
@@ -176,7 +173,17 @@
    "outputs": [],
    "source": [
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
-    "docs: List[Document] = db1.similarity_search(query, k=4)"
+    "docs: List[Document] = db1.similarity_search(query, k=4)\n",
+    "for doc in docs:\n",
+    "    print(doc.page_content)\n",
+    "    print(\"======================\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Similarity Search with Filter"
    ]
   },
   {
@@ -185,6 +192,36 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "from pgvecto_rs.sdk.filters import meta_contains\n",
+    "\n",
+    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
+    "docs: List[Document] = db1.similarity_search(\n",
+    "    query, k=4, filter=meta_contains({\"source\": \"../../modules/state_of_the_union.txt\"})\n",
+    ")\n",
+    "\n",
+    "for doc in docs:\n",
+    "    print(doc.page_content)\n",
+    "    print(\"======================\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Or:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
+    "docs: List[Document] = db1.similarity_search(\n",
+    "    query, k=4, filter={\"source\": \"../../modules/state_of_the_union.txt\"}\n",
+    ")\n",
+    "\n",
    "for doc in docs:\n",
    "    print(doc.page_content)\n",
    "    print(\"======================\")"
@@ -207,7 +244,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.3"
+   "version": "3.11.6"
   }
  },
  "nbformat": 4,
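Taken together, the notebook now exercises the store end to end roughly as follows (a condensed sketch; the connection URL and collection name are placeholders, since the cells that build them sit outside this diff):

    from langchain.text_splitter import CharacterTextSplitter
    from langchain_community.document_loaders import TextLoader
    from langchain_community.embeddings.fake import FakeEmbeddings
    from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs

    documents = TextLoader("../../modules/state_of_the_union.txt").load()
    docs = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0).split_documents(documents)
    embeddings = FakeEmbeddings(size=3)

    db1 = PGVecto_rs.from_documents(
        docs,
        embeddings,
        db_url="postgresql+psycopg://postgres:postgres@localhost:5432/postgres",  # placeholder
        collection_name="state_of_the_union",  # illustrative
    )

    query = "What did the president say about Ketanji Brown Jackson"
    for doc in db1.similarity_search(query, k=4):
        print(doc.page_content)
        print("======================")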
langchain_community/vectorstores/pgvecto_rs.py:

@@ -1,32 +1,17 @@
 from __future__ import annotations
 
 import uuid
-from typing import Any, Iterable, List, Literal, Optional, Tuple, Type
+from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
 
-import numpy as np
-import sqlalchemy
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
-from sqlalchemy import insert, select
-from sqlalchemy.dialects import postgresql
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
-from sqlalchemy.orm.session import Session
-
-
-class _ORMBase(DeclarativeBase):
-    __tablename__: str
-    id: Mapped[uuid.UUID]
-    text: Mapped[str]
-    meta: Mapped[dict]
-    embedding: Mapped[np.ndarray]
 
 
 class PGVecto_rs(VectorStore):
     """VectorStore backed by pgvecto_rs."""
 
-    _engine: sqlalchemy.engine.Engine
-    _table: Type[_ORMBase]
+    _store = None
     _embedding: Embeddings
 
     def __init__(
@@ -45,28 +30,22 @@ class PGVecto_rs(VectorStore):
             db_url: Database URL.
             collection_name: Name of the collection.
             new_table: Whether to create a new table or connect to an existing one.
+                If true, the table will be dropped if it exists, then recreated.
                 Defaults to False.
         """
         try:
-            from pgvecto_rs.sqlalchemy import Vector
+            from pgvecto_rs.sdk import PGVectoRs
         except ImportError as e:
             raise ImportError(
-                "Unable to import pgvector_rs, please install with "
-                "`pip install pgvector_rs`."
+                "Unable to import pgvecto_rs.sdk, please install with "
+                '`pip install "pgvecto_rs[sdk]"`.'
             ) from e
-
-        class _Table(_ORMBase):
-            __tablename__ = f"collection_{collection_name}"
-            id: Mapped[uuid.UUID] = mapped_column(
-                postgresql.UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
-            )
-            text: Mapped[str] = mapped_column(sqlalchemy.String)
-            meta: Mapped[dict] = mapped_column(postgresql.JSONB)
-            embedding: Mapped[np.ndarray] = mapped_column(Vector(dimension))
-
-        self._engine = sqlalchemy.create_engine(db_url)
-        self._table = _Table
-        self._table.__table__.create(self._engine, checkfirst=not new_table)  # type: ignore
+        self._store = PGVectoRs(
+            db_url=db_url,
+            collection_name=collection_name,
+            dimension=dimension,
+            recreate=new_table,
+        )
         self._embedding = embedding
 
         # ================ Create interface =================
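The constructor no longer declares an ORM table or creates an engine; table lifecycle is handled by `PGVectoRs` itself, with `new_table` forwarded as `recreate`. Direct construction then looks like this (a sketch; the URL is a placeholder):

    from langchain_community.embeddings.fake import FakeEmbeddings
    from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs

    db = PGVecto_rs(
        embedding=FakeEmbeddings(size=3),
        dimension=3,
        db_url="postgresql+psycopg://postgres:postgres@localhost:5432/postgres",  # placeholder
        collection_name="demo",
        new_table=True,  # drop-and-recreate, forwarded to PGVectoRs(recreate=...)
    )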
@@ -90,7 +69,6 @@ class PGVecto_rs(VectorStore):
             dimension=dimension,
             db_url=db_url,
             collection_name=collection_name,
-            new_table=True,
         )
         _self.add_texts(texts, metadatas, **kwargs)
         return _self
@@ -148,19 +126,15 @@ class PGVecto_rs(VectorStore):
             List of ids of the added texts.
 
         """
+        from pgvecto_rs.sdk import Record
+
         embeddings = self._embedding.embed_documents(list(texts))
-        with Session(self._engine) as _session:
-            results: List[str] = []
-            for text, embedding, metadata in zip(
-                texts, embeddings, metadatas or [dict()] * len(list(texts))
-            ):
-                t = insert(self._table).values(
-                    text=text, meta=metadata, embedding=embedding
-                )
-                id = _session.execute(t).inserted_primary_key[0]  # type: ignore
-                results.append(str(id))
-            _session.commit()
-            return results
+        records = [
+            Record.from_text(text, embedding, meta)
+            for text, embedding, meta in zip(texts, embeddings, metadatas or [])
+        ]
+        self._store.insert(records)
+        return [str(record.id) for record in records]
 
     def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
         """Run more documents through the embeddings and add to the vectorstore.
|
|||||||
distance_func: Literal[
|
distance_func: Literal[
|
||||||
"sqrt_euclid", "neg_dot_prod", "ned_cos"
|
"sqrt_euclid", "neg_dot_prod", "ned_cos"
|
||||||
] = "sqrt_euclid",
|
] = "sqrt_euclid",
|
||||||
|
filter: Union[None, Dict[str, Any], Any] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
"""Return docs most similar to query vector, with its score."""
|
"""Return docs most similar to query vector, with its score."""
|
||||||
with Session(self._engine) as _session:
|
|
||||||
real_distance_func = (
|
|
||||||
self._table.embedding.squared_euclidean_distance
|
|
||||||
if distance_func == "sqrt_euclid"
|
|
||||||
else self._table.embedding.negative_dot_product_distance
|
|
||||||
if distance_func == "neg_dot_prod"
|
|
||||||
else self._table.embedding.negative_cosine_distance
|
|
||||||
if distance_func == "ned_cos"
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
if real_distance_func is None:
|
|
||||||
raise ValueError("Invalid distance function")
|
|
||||||
|
|
||||||
t = (
|
from pgvecto_rs.sdk.filters import meta_contains
|
||||||
select(self._table, real_distance_func(query_vector).label("score"))
|
|
||||||
.order_by("score")
|
distance_func_map = {
|
||||||
.limit(k) # type: ignore
|
"sqrt_euclid": "<->",
|
||||||
|
"neg_dot_prod": "<#>",
|
||||||
|
"ned_cos": "<=>",
|
||||||
|
}
|
||||||
|
if filter is None:
|
||||||
|
real_filter = None
|
||||||
|
elif isinstance(filter, dict):
|
||||||
|
real_filter = meta_contains(filter)
|
||||||
|
else:
|
||||||
|
real_filter = filter
|
||||||
|
results = self._store.search(
|
||||||
|
query_vector,
|
||||||
|
distance_func_map[distance_func],
|
||||||
|
k,
|
||||||
|
filter=real_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
(Document(page_content=row[0].text, metadata=row[0].meta), row[1])
|
(
|
||||||
for row in _session.execute(t)
|
Document(
|
||||||
|
page_content=res[0].text,
|
||||||
|
metadata=res[0].meta,
|
||||||
|
),
|
||||||
|
res[1],
|
||||||
|
)
|
||||||
|
for res in results
|
||||||
]
|
]
|
||||||
|
|
||||||
def similarity_search_by_vector(
|
def similarity_search_by_vector(
|
||||||
@ -218,11 +202,12 @@ class PGVecto_rs(VectorStore):
|
|||||||
distance_func: Literal[
|
distance_func: Literal[
|
||||||
"sqrt_euclid", "neg_dot_prod", "ned_cos"
|
"sqrt_euclid", "neg_dot_prod", "ned_cos"
|
||||||
] = "sqrt_euclid",
|
] = "sqrt_euclid",
|
||||||
|
filter: Optional[Any] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
return [
|
return [
|
||||||
doc
|
doc
|
||||||
for doc, score in self.similarity_search_with_score_by_vector(
|
for doc, _score in self.similarity_search_with_score_by_vector(
|
||||||
embedding, k, distance_func, **kwargs
|
embedding, k, distance_func, **kwargs
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -254,7 +239,7 @@ class PGVecto_rs(VectorStore):
         query_vector = self._embedding.embed_query(query)
         return [
             doc
-            for doc, score in self.similarity_search_with_score_by_vector(
+            for doc, _score in self.similarity_search_with_score_by_vector(
                 query_vector, k, distance_func, **kwargs
             )
         ]