From 34a24d4df62960139c924346b4e3bc8eb32696d0 Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Mon, 8 Apr 2024 09:34:10 -0400
Subject: [PATCH] postgres[minor]: Add pgvector community as is (#20096)

This moves langchain pgvector community as is

The only modification is support for psycopg3 rather than psycopg2!
---
 .../postgres/langchain_postgres/_utils.py | 82 +
 .../langchain_postgres/vectorstores.py | 1349 +++++++++++++++++
 libs/partners/postgres/poetry.lock | 230 ++-
 libs/partners/postgres/pyproject.toml | 3 +
 .../integration_tests/fake_embeddings.py | 28 +
 .../integration_tests/fixtures/__init__.py | 0
 .../fixtures/filtering_test_cases.py | 218 +++
 .../integration_tests/test_vectorstore.py | 505 ++++++
 8 files changed, 2408 insertions(+), 7 deletions(-)
 create mode 100644 libs/partners/postgres/langchain_postgres/_utils.py
 create mode 100644 libs/partners/postgres/langchain_postgres/vectorstores.py
 create mode 100644 libs/partners/postgres/tests/integration_tests/fake_embeddings.py
 create mode 100644 libs/partners/postgres/tests/integration_tests/fixtures/__init__.py
 create mode 100644 libs/partners/postgres/tests/integration_tests/fixtures/filtering_test_cases.py
 create mode 100644 libs/partners/postgres/tests/integration_tests/test_vectorstore.py

diff --git a/libs/partners/postgres/langchain_postgres/_utils.py b/libs/partners/postgres/langchain_postgres/_utils.py
new file mode 100644
index 00000000000..9d8055af7ab
--- /dev/null
+++ b/libs/partners/postgres/langchain_postgres/_utils.py
@@ -0,0 +1,82 @@
+"""Copied over from langchain_community.

+This code should be moved to langchain proper or removed entirely.
+"""
+
+import logging
+from typing import List, Union
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
+
+
+def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
+ """Row-wise cosine similarity between two equal-width matrices."""
+ if len(X) == 0 or len(Y) == 0:
+ return np.array([])
+
+ X = np.array(X)
+ Y = np.array(Y)
+ if X.shape[1] != Y.shape[1]:
+ raise ValueError(
+ f"Number of columns in X and Y must be the same. X has shape {X.shape} "
+ f"and Y has shape {Y.shape}."
+ )
+ try:
+ import simsimd as simd # type: ignore
+
+ X = np.array(X, dtype=np.float32)
+ Y = np.array(Y, dtype=np.float32)
+ Z = 1 - simd.cdist(X, Y, metric="cosine")
+ if isinstance(Z, float):
+ return np.array([Z])
+ return np.array(Z)
+ except ImportError:
+ logger.debug(
+ "Unable to import simsimd, defaulting to NumPy implementation. If you want "
+ "to use simsimd please install with `pip install simsimd`."
+ )
+ X_norm = np.linalg.norm(X, axis=1)
+ Y_norm = np.linalg.norm(Y, axis=1)
+ # Ignore divide-by-zero and invalid-value runtime warnings; the resulting NaN/inf entries are zeroed out below.
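+ # Illustrative expectation (hypothetical values): for the unit basis
+ # vectors below, cosine_similarity([[1, 0], [0, 1]], [[1, 0]])
+ # evaluates to [[1.0], [0.0]].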
+ with np.errstate(divide="ignore", invalid="ignore"): + similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) + similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 + return similarity + + +def maximal_marginal_relevance( + query_embedding: np.ndarray, + embedding_list: list, + lambda_mult: float = 0.5, + k: int = 4, +) -> List[int]: + """Calculate maximal marginal relevance.""" + if min(k, len(embedding_list)) <= 0: + return [] + if query_embedding.ndim == 1: + query_embedding = np.expand_dims(query_embedding, axis=0) + similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0] + most_similar = int(np.argmax(similarity_to_query)) + idxs = [most_similar] + selected = np.array([embedding_list[most_similar]]) + while len(idxs) < min(k, len(embedding_list)): + best_score = -np.inf + idx_to_add = -1 + similarity_to_selected = cosine_similarity(embedding_list, selected) + for i, query_score in enumerate(similarity_to_query): + if i in idxs: + continue + redundant_score = max(similarity_to_selected[i]) + equation_score = ( + lambda_mult * query_score - (1 - lambda_mult) * redundant_score + ) + if equation_score > best_score: + best_score = equation_score + idx_to_add = i + idxs.append(idx_to_add) + selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) + return idxs diff --git a/libs/partners/postgres/langchain_postgres/vectorstores.py b/libs/partners/postgres/langchain_postgres/vectorstores.py new file mode 100644 index 00000000000..6750fe7a258 --- /dev/null +++ b/libs/partners/postgres/langchain_postgres/vectorstores.py @@ -0,0 +1,1349 @@ +from __future__ import annotations + +import contextlib +import enum +import logging +import uuid +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, + Type, +) + +import numpy as np +import sqlalchemy +from langchain_core._api import warn_deprecated +from sqlalchemy import SQLColumnExpression, cast, delete, func +from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID +from sqlalchemy.orm import Session, relationship + +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base + +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.runnables.config import run_in_executor +from langchain_core.utils import get_from_dict_or_env +from langchain_core.vectorstores import VectorStore + +from langchain_postgres._utils import maximal_marginal_relevance + + +class DistanceStrategy(str, enum.Enum): + """Enumerator of the Distance strategies.""" + + EUCLIDEAN = "l2" + COSINE = "cosine" + MAX_INNER_PRODUCT = "inner" + + +DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE + +Base = declarative_base() # type: Any + + +_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" + + +class BaseModel(Base): + """Base model for the SQL stores.""" + + __abstract__ = True + uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + + +_classes: Any = None + +COMPARISONS_TO_NATIVE = { + "$eq": "==", + "$ne": "!=", + "$lt": "<", + "$lte": "<=", + "$gt": ">", + "$gte": ">=", +} + +SPECIAL_CASED_OPERATORS = { + "$in", + "$nin", + "$between", +} + +TEXT_OPERATORS = { + "$like", + "$ilike", +} + +LOGICAL_OPERATORS = {"$and", "$or"} + +SUPPORTED_OPERATORS = ( + set(COMPARISONS_TO_NATIVE) + .union(TEXT_OPERATORS) + .union(LOGICAL_OPERATORS) + .union(SPECIAL_CASED_OPERATORS) +) + + +def _get_embedding_collection_store( + 
vector_dimension: Optional[int] = None, *, use_jsonb: bool = True
+) -> Any:
+ global _classes
+ if _classes is not None:
+ return _classes
+
+ from pgvector.sqlalchemy import Vector # type: ignore
+
+ class CollectionStore(BaseModel):
+ """Collection store."""
+
+ __tablename__ = "langchain_pg_collection"
+
+ name = sqlalchemy.Column(sqlalchemy.String)
+ cmetadata = sqlalchemy.Column(JSON)
+
+ embeddings = relationship(
+ "EmbeddingStore",
+ back_populates="collection",
+ passive_deletes=True,
+ )
+
+ @classmethod
+ def get_by_name(
+ cls, session: Session, name: str
+ ) -> Optional["CollectionStore"]:
+ return session.query(cls).filter(cls.name == name).first() # type: ignore
+
+ @classmethod
+ def get_or_create(
+ cls,
+ session: Session,
+ name: str,
+ cmetadata: Optional[dict] = None,
+ ) -> Tuple["CollectionStore", bool]:
+ """
+ Get or create a collection.
+ Returns (collection, created), where created is True if the collection was created.
+ """ # noqa: E501
+ created = False
+ collection = cls.get_by_name(session, name)
+ if collection:
+ return collection, created
+
+ collection = cls(name=name, cmetadata=cmetadata)
+ session.add(collection)
+ session.commit()
+ created = True
+ return collection, created
+
+ if use_jsonb:
+ # TODO(PRIOR TO LANDING): Create a gin index on the cmetadata field
+ class EmbeddingStore(BaseModel):
+ """Embedding store."""
+
+ __tablename__ = "langchain_pg_embedding"
+
+ collection_id = sqlalchemy.Column(
+ UUID(as_uuid=True),
+ sqlalchemy.ForeignKey(
+ f"{CollectionStore.__tablename__}.uuid",
+ ondelete="CASCADE",
+ ),
+ )
+ collection = relationship(CollectionStore, back_populates="embeddings")
+
+ embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
+ document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+ cmetadata = sqlalchemy.Column(JSONB, nullable=True)
+
+ # custom_id : any user defined id
+ custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+
+ __table_args__ = (
+ sqlalchemy.Index(
+ "ix_cmetadata_gin",
+ "cmetadata",
+ postgresql_using="gin",
+ postgresql_ops={"cmetadata": "jsonb_path_ops"},
+ ),
+ )
+ else:
+ # For backward compatibility with older versions of pgvector.
+ # This should be removed in the future (remove during migration).
+ class EmbeddingStore(BaseModel): # type: ignore[no-redef]
+ """Embedding store."""
+
+ __tablename__ = "langchain_pg_embedding"
+
+ collection_id = sqlalchemy.Column(
+ UUID(as_uuid=True),
+ sqlalchemy.ForeignKey(
+ f"{CollectionStore.__tablename__}.uuid",
+ ondelete="CASCADE",
+ ),
+ )
+ collection = relationship(CollectionStore, back_populates="embeddings")
+
+ embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
+ document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+ cmetadata = sqlalchemy.Column(JSON, nullable=True)
+
+ # custom_id : any user defined id
+ custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+
+ _classes = (EmbeddingStore, CollectionStore)
+
+ return _classes
+
+
+def _results_to_docs(docs_and_scores: Any) -> List[Document]:
+ """Return docs from docs and scores."""
+ return [doc for doc, _ in docs_and_scores]
+
+
+class PGVector(VectorStore):
+ """`Postgres`/`PGVector` vector store.
+
+ To use, you should have the ``pgvector`` python package installed.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_postgres.vectorstores import PGVector
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
+
+ CONNECTION_STRING = "postgresql+psycopg://hwc@localhost:5432/test3"
+ COLLECTION_NAME = "state_of_the_union_test"
+ embeddings = OpenAIEmbeddings()
+ vectorstore = PGVector.from_documents(
+ embedding=embeddings,
+ documents=docs,
+ collection_name=COLLECTION_NAME,
+ connection_string=CONNECTION_STRING,
+ use_jsonb=True,
+ )
+ """
+
+ def __init__(
+ self,
+ connection_string: str,
+ embedding_function: Embeddings,
+ embedding_length: Optional[int] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ collection_metadata: Optional[dict] = None,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ pre_delete_collection: bool = False,
+ logger: Optional[logging.Logger] = None,
+ relevance_score_fn: Optional[Callable[[float], float]] = None,
+ *,
+ connection: Optional[sqlalchemy.engine.Connection] = None,
+ engine_args: Optional[dict[str, Any]] = None,
+ use_jsonb: bool = False,
+ create_extension: bool = True,
+ ) -> None:
+ """Initialize the PGVector store.
+
+ Args:
+ connection_string: Postgres connection string.
+ embedding_function: Any embedding function implementing
+ `langchain.embeddings.base.Embeddings` interface.
+ embedding_length: The length of the embedding vector. (default: None)
+ NOTE: This is not mandatory. Defining it will prevent vectors of
+ any other size from being added to the embeddings table but,
+ without it, the embeddings can't be indexed.
+ collection_name: The name of the collection to use. (default: langchain)
+ NOTE: This is not the name of the table, but the name of the collection.
+ The tables will be created when initializing the store (if they do
+ not exist), so make sure the user has the right permissions to
+ create tables.
+ distance_strategy: The distance strategy to use. (default: COSINE)
+ pre_delete_collection: If True, will delete the collection if it exists.
+ (default: False). Useful for testing.
+ engine_args: SQLAlchemy's create engine arguments.
+ use_jsonb: Use JSONB instead of JSON for metadata. (default: False)
+ Using plain JSON is strongly discouraged as it is less efficient
+ for querying.
+ It's provided here for backwards compatibility with older versions,
+ and will be removed in the future.
+ create_extension: If True, will create the vector extension if it
+ doesn't exist. Disabling creation is useful when using read-only
+ databases.
+ """
+ self.connection_string = connection_string
+ self.embedding_function = embedding_function
+ self._embedding_length = embedding_length
+ self.collection_name = collection_name
+ self.collection_metadata = collection_metadata
+ self._distance_strategy = distance_strategy
+ self.pre_delete_collection = pre_delete_collection
+ self.logger = logger or logging.getLogger(__name__)
+ self.override_relevance_score_fn = relevance_score_fn
+ self.engine_args = engine_args or {}
+ self._bind = connection if connection else self._create_engine()
+ self.use_jsonb = use_jsonb
+ self.create_extension = create_extension
+
+ if not use_jsonb:
+ # Emit a pending deprecation warning for JSON metadata.
+ warn_deprecated(
+ "0.0.29",
+ pending=True,
+ message=(
+ "Please use JSONB instead of JSON for metadata. "
+ "This change will allow for more efficient querying that "
+ "involves filtering based on metadata. "
+ "Please note that filtering operators have been changed " + "when using JSOB metadata to be prefixed with a $ sign " + "to avoid name collisions with columns. " + "If you're using an existing database, you will need to create a" + "db migration for your metadata column to be JSONB and update your " + "queries to use the new operators. " + ), + alternative=( + "Instantiate with use_jsonb=True to use JSONB instead " + "of JSON for metadata." + ), + ) + self.__post_init__() + + def __post_init__( + self, + ) -> None: + """Initialize the store.""" + if self.create_extension: + self.create_vector_extension() + + EmbeddingStore, CollectionStore = _get_embedding_collection_store( + self._embedding_length, use_jsonb=self.use_jsonb + ) + self.CollectionStore = CollectionStore + self.EmbeddingStore = EmbeddingStore + self.create_tables_if_not_exists() + self.create_collection() + + def __del__(self) -> None: + if isinstance(self._bind, sqlalchemy.engine.Connection): + self._bind.close() + + @property + def embeddings(self) -> Embeddings: + return self.embedding_function + + def _create_engine(self) -> sqlalchemy.engine.Engine: + return sqlalchemy.create_engine(url=self.connection_string, **self.engine_args) + + def create_vector_extension(self) -> None: + try: + with Session(self._bind) as session: # type: ignore[arg-type] + # The advisor lock fixes issue arising from concurrent + # creation of the vector extension. + # https://github.com/langchain-ai/langchain/issues/12933 + # For more information see: + # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS + statement = sqlalchemy.text( + "BEGIN;" + "SELECT pg_advisory_xact_lock(1573678846307946496);" + "CREATE EXTENSION IF NOT EXISTS vector;" + "COMMIT;" + ) + session.execute(statement) + session.commit() + except Exception as e: + raise Exception(f"Failed to create vector extension: {e}") from e + + def create_tables_if_not_exists(self) -> None: + with Session(self._bind) as session, session.begin(): # type: ignore[arg-type] + Base.metadata.create_all(session.get_bind()) + + def drop_tables(self) -> None: + with Session(self._bind) as session, session.begin(): # type: ignore[arg-type] + Base.metadata.drop_all(session.get_bind()) + + def create_collection(self) -> None: + if self.pre_delete_collection: + self.delete_collection() + with Session(self._bind) as session: # type: ignore[arg-type] + self.CollectionStore.get_or_create( + session, self.collection_name, cmetadata=self.collection_metadata + ) + + def delete_collection(self) -> None: + self.logger.debug("Trying to delete collection") + with Session(self._bind) as session: # type: ignore[arg-type] + collection = self.get_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + session.delete(collection) + session.commit() + + @contextlib.contextmanager + def _make_session(self) -> Generator[Session, None, None]: + """Create a context manager for the session, bind to _conn string.""" + yield Session(self._bind) # type: ignore[arg-type] + + def delete( + self, + ids: Optional[List[str]] = None, + collection_only: bool = False, + **kwargs: Any, + ) -> None: + """Delete vectors by ids or uuids. + + Args: + ids: List of ids to delete. + collection_only: Only delete ids in the collection. 
+ """ + with Session(self._bind) as session: # type: ignore[arg-type] + if ids is not None: + self.logger.debug( + "Trying to delete vectors by ids (represented by the model " + "using the custom ids field)" + ) + + stmt = delete(self.EmbeddingStore) + + if collection_only: + collection = self.get_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + + stmt = stmt.where( + self.EmbeddingStore.collection_id == collection.uuid + ) + + stmt = stmt.where(self.EmbeddingStore.custom_id.in_(ids)) + session.execute(stmt) + session.commit() + + def get_collection(self, session: Session) -> Any: + return self.CollectionStore.get_by_name(session, self.collection_name) + + @classmethod + def __from( + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection_string: Optional[str] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = False, + **kwargs: Any, + ) -> PGVector: + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + if not metadatas: + metadatas = [{} for _ in texts] + if connection_string is None: + connection_string = cls.get_connection_string(kwargs) + + store = cls( + connection_string=connection_string, + collection_name=collection_name, + embedding_function=embedding, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + use_jsonb=use_jsonb, + **kwargs, + ) + + store.add_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + + return store + + def add_embeddings( + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Add embeddings to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + embeddings: List of list of embedding vectors. + metadatas: List of metadatas associated with the texts. + kwargs: vectorstore specific parameters + """ + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + if not metadatas: + metadatas = [{} for _ in texts] + + with Session(self._bind) as session: # type: ignore[arg-type] + collection = self.get_collection(session) + if not collection: + raise ValueError("Collection not found") + documents = [] + for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids): + embedding_store = self.EmbeddingStore( + embedding=embedding, + document=text, + cmetadata=metadata, + custom_id=id, + collection_id=collection.uuid, + ) + documents.append(embedding_store) + session.bulk_save_objects(documents) + session.commit() + + return ids + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. + kwargs: vectorstore specific parameters + + Returns: + List of ids from adding the texts into the vectorstore. 
+ """ + embeddings = self.embedding_function.embed_documents(list(texts)) + return self.add_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + + def similarity_search( + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + """Run similarity search with PGVector with distance. + + Args: + query (str): Query text to search for. + k (int): Number of results to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List of Documents most similar to the query. + """ + embedding = self.embedding_function.embed_query(text=query) + return self.similarity_search_by_vector( + embedding=embedding, + k=k, + filter=filter, + ) + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List of Documents most similar to the query and score for each. + """ + embedding = self.embedding_function.embed_query(query) + docs = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter + ) + return docs + + @property + def distance_strategy(self) -> Any: + if self._distance_strategy == DistanceStrategy.EUCLIDEAN: + return self.EmbeddingStore.embedding.l2_distance + elif self._distance_strategy == DistanceStrategy.COSINE: + return self.EmbeddingStore.embedding.cosine_distance + elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + return self.EmbeddingStore.embedding.max_inner_product + else: + raise ValueError( + f"Got unexpected value for distance: {self._distance_strategy}. " + f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}." + ) + + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + ) -> List[Tuple[Document, float]]: + results = self.__query_collection(embedding=embedding, k=k, filter=filter) + + return self._results_to_docs_and_scores(results) + + def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]: + """Return docs and scores from results.""" + docs = [ + ( + Document( + page_content=result.EmbeddingStore.document, + metadata=result.EmbeddingStore.cmetadata, + ), + result.distance if self.embedding_function is not None else None, + ) + for result in results + ] + return docs + + def _handle_field_filter( + self, + field: str, + value: Any, + ) -> SQLColumnExpression: + """Create a filter for a specific field. + + Args: + field: name of field + value: value to filter + If provided as is then this will be an equality filter + If provided as a dictionary then this will be a filter, the key + will be the operator and the value will be the value to filter by + + Returns: + sqlalchemy expression + """ + if not isinstance(field, str): + raise ValueError( + f"field should be a string but got: {type(field)} with value: {field}" + ) + + if field.startswith("$"): + raise ValueError( + f"Invalid filter condition. Expected a field but got an operator: " + f"{field}" + ) + + # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters + if not field.isidentifier(): + raise ValueError( + f"Invalid field name: {field}. Expected a valid identifier." 
+ )
+
+ if isinstance(value, dict):
+ # This is a filter specification
+ if len(value) != 1:
+ raise ValueError(
+ "Invalid filter condition. Expected a value which "
+ "is a dictionary with a single key that corresponds to an operator "
+ f"but got a dictionary with {len(value)} keys. The first few "
+ f"keys are: {list(value.keys())[:3]}"
+ )
+ operator, filter_value = list(value.items())[0]
+ # Verify that the operator is a supported operator
+ if operator not in SUPPORTED_OPERATORS:
+ raise ValueError(
+ f"Invalid operator: {operator}. "
+ f"Expected one of {SUPPORTED_OPERATORS}"
+ )
+ else: # Then we assume an equality operator
+ operator = "$eq"
+ filter_value = value
+
+ if operator in COMPARISONS_TO_NATIVE:
+ # Then we implement a native comparison filter
+ # native is trusted input
+ native = COMPARISONS_TO_NATIVE[operator]
+ return func.jsonb_path_match(
+ self.EmbeddingStore.cmetadata,
+ cast(f"$.{field} {native} $value", JSONPATH),
+ cast({"value": filter_value}, JSONB),
+ )
+ elif operator == "$between":
+ # Use AND with two comparisons
+ low, high = filter_value
+
+ lower_bound = func.jsonb_path_match(
+ self.EmbeddingStore.cmetadata,
+ cast(f"$.{field} >= $value", JSONPATH),
+ cast({"value": low}, JSONB),
+ )
+ upper_bound = func.jsonb_path_match(
+ self.EmbeddingStore.cmetadata,
+ cast(f"$.{field} <= $value", JSONPATH),
+ cast({"value": high}, JSONB),
+ )
+ return sqlalchemy.and_(lower_bound, upper_bound)
+ elif operator in {"$in", "$nin", "$like", "$ilike"}:
+ # Force coercion of the values to text
+ if operator in {"$in", "$nin"}:
+ for val in filter_value:
+ if not isinstance(val, (str, int, float)):
+ raise NotImplementedError(
+ f"Unsupported type: {type(val)} for value: {val}"
+ )
+
+ queried_field = self.EmbeddingStore.cmetadata[field].astext
+
+ if operator in {"$in"}:
+ return queried_field.in_([str(val) for val in filter_value])
+ elif operator in {"$nin"}:
+ return queried_field.not_in([str(val) for val in filter_value])
+ elif operator in {"$like"}:
+ return queried_field.like(filter_value)
+ elif operator in {"$ilike"}:
+ return queried_field.ilike(filter_value)
+ else:
+ raise NotImplementedError()
+ else:
+ raise NotImplementedError()
+
+ def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def]
+ """Deprecated functionality.
+
+ This is for backwards compatibility with the JSON based schema for metadata.
+ It uses incorrect operator syntax (operators are not prefixed with $).
+
+ This implementation is not efficient, and has bugs associated with
+ the way that it handles numeric filter clauses.
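+
+ Example of the legacy (non-$-prefixed) syntax handled here, using
+ hypothetical field names:
+ .. code-block:: python
+
+ {"source": {"in": ["notion", "confluence"]}}
+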
+ """ + IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne" + EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and" + + value_case_insensitive = {k.lower(): v for k, v in value.items()} + if IN in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.in_( + value_case_insensitive[IN] + ) + elif NIN in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.not_in( + value_case_insensitive[NIN] + ) + elif BETWEEN in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.between( + str(value_case_insensitive[BETWEEN][0]), + str(value_case_insensitive[BETWEEN][1]), + ) + elif GT in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext > str( + value_case_insensitive[GT] + ) + elif LT in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext < str( + value_case_insensitive[LT] + ) + elif NE in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext != str( + value_case_insensitive[NE] + ) + elif EQ in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str( + value_case_insensitive[EQ] + ) + elif LIKE in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.like( + value_case_insensitive[LIKE] + ) + elif CONTAINS in map(str.lower, value): + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.contains( + value_case_insensitive[CONTAINS] + ) + elif OR in map(str.lower, value): + or_clauses = [ + self._create_filter_clause(key, sub_value) + for sub_value in value_case_insensitive[OR] + ] + filter_by_metadata = sqlalchemy.or_(*or_clauses) + elif AND in map(str.lower, value): + and_clauses = [ + self._create_filter_clause(key, sub_value) + for sub_value in value_case_insensitive[AND] + ] + filter_by_metadata = sqlalchemy.and_(*and_clauses) + + else: + filter_by_metadata = None + + return filter_by_metadata + + def _create_filter_clause_json_deprecated( + self, filter: Any + ) -> List[SQLColumnExpression]: + """Convert filters from IR to SQL clauses. + + **DEPRECATED** This functionality will be deprecated in the future. + + It implements translation of filters for a schema that uses JSON + for metadata rather than the JSONB field which is more efficient + for querying. + """ + filter_clauses = [] + for key, value in filter.items(): + if isinstance(value, dict): + filter_by_metadata = self._create_filter_clause_deprecated(key, value) + + if filter_by_metadata is not None: + filter_clauses.append(filter_by_metadata) + else: + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str( + value + ) + filter_clauses.append(filter_by_metadata) + return filter_clauses + + def _create_filter_clause(self, filters: Any) -> Any: + """Convert LangChain IR filter representation to matching SQLAlchemy clauses. + + At the top level, we still don't know if we're working with a field + or an operator for the keys. After we've determined that we can + call the appropriate logic to handle filter creation. + + Args: + filters: Dictionary of filters to apply to the query. + + Returns: + SQLAlchemy clause to apply to the query. 
+ """ + if isinstance(filters, dict): + if len(filters) == 1: + # The only operators allowed at the top level are $AND and $OR + # First check if an operator or a field + key, value = list(filters.items())[0] + if key.startswith("$"): + # Then it's an operator + if key.lower() not in ["$and", "$or"]: + raise ValueError( + f"Invalid filter condition. Expected $and or $or " + f"but got: {key}" + ) + else: + # Then it's a field + return self._handle_field_filter(key, filters[key]) + + # Here we handle the $and and $or operators + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) + if key.lower() == "$and": + and_ = [self._create_filter_clause(el) for el in value] + if len(and_) > 1: + return sqlalchemy.and_(*and_) + elif len(and_) == 1: + return and_[0] + else: + raise ValueError( + "Invalid filter condition. Expected a dictionary " + "but got an empty dictionary" + ) + elif key.lower() == "$or": + or_ = [self._create_filter_clause(el) for el in value] + if len(or_) > 1: + return sqlalchemy.or_(*or_) + elif len(or_) == 1: + return or_[0] + else: + raise ValueError( + "Invalid filter condition. Expected a dictionary " + "but got an empty dictionary" + ) + else: + raise ValueError( + f"Invalid filter condition. Expected $and or $or " + f"but got: {key}" + ) + elif len(filters) > 1: + # Then all keys have to be fields (they cannot be operators) + for key in filters.keys(): + if key.startswith("$"): + raise ValueError( + f"Invalid filter condition. Expected a field but got: {key}" + ) + # These should all be fields and combined using an $and operator + and_ = [self._handle_field_filter(k, v) for k, v in filters.items()] + if len(and_) > 1: + return sqlalchemy.and_(*and_) + elif len(and_) == 1: + return and_[0] + else: + raise ValueError( + "Invalid filter condition. Expected a dictionary " + "but got an empty dictionary" + ) + else: + raise ValueError("Got an empty dictionary for filters.") + else: + raise ValueError( + f"Invalid type: Expected a dictionary but got type: {type(filters)}" + ) + + def __query_collection( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: + """Query the collection.""" + with Session(self._bind) as session: # type: ignore[arg-type] + collection = self.get_collection(session) + if not collection: + raise ValueError("Collection not found") + + filter_by = [self.EmbeddingStore.collection_id == collection.uuid] + if filter: + if self.use_jsonb: + filter_clauses = self._create_filter_clause(filter) + if filter_clauses is not None: + filter_by.append(filter_clauses) + else: + # Old way of doing things + filter_clauses = self._create_filter_clause_json_deprecated(filter) + filter_by.extend(filter_clauses) + + _type = self.EmbeddingStore + + results: List[Any] = ( + session.query( + self.EmbeddingStore, + self.distance_strategy(embedding).label("distance"), # type: ignore + ) + .filter(*filter_by) + .order_by(sqlalchemy.asc("distance")) + .join( + self.CollectionStore, + self.EmbeddingStore.collection_id == self.CollectionStore.uuid, + ) + .limit(k) + .all() + ) + + return results + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. 
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+ Returns:
+ List of Documents most similar to the query vector.
+ """
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter
+ )
+ return _results_to_docs(docs_and_scores)
+
+ @classmethod
+ def from_texts(
+ cls: Type[PGVector],
+ texts: List[str],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ *,
+ use_jsonb: bool = False,
+ **kwargs: Any,
+ ) -> PGVector:
+ """
+ Return VectorStore initialized from texts and embeddings.
+ Postgres connection string is required: either pass it as a parameter,
+ or set the PGVECTOR_CONNECTION_STRING environment variable.
+ """
+ embeddings = embedding.embed_documents(list(texts))
+
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ use_jsonb=use_jsonb,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_embeddings(
+ cls,
+ text_embeddings: List[Tuple[str, List[float]]],
+ embedding: Embeddings,
+ metadatas: Optional[List[dict]] = None,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGVector:
+ """Construct PGVector wrapper from raw documents and
+ pre-generated embeddings.
+
+ Return VectorStore initialized from documents and embeddings.
+ Postgres connection string is required: either pass it as a parameter,
+ or set the PGVECTOR_CONNECTION_STRING environment variable.
+
+ Example:
+ .. code-block:: python
+
+ from langchain_postgres.vectorstores import PGVector
+ from langchain_community.embeddings import OpenAIEmbeddings
+
+ embeddings = OpenAIEmbeddings()
+ text_embeddings = embeddings.embed_documents(texts)
+ text_embedding_pairs = list(zip(texts, text_embeddings))
+ vectorstore = PGVector.from_embeddings(text_embedding_pairs, embeddings)
+ """
+ texts = [t[0] for t in text_embeddings]
+ embeddings = [t[1] for t in text_embeddings]
+
+ return cls.__from(
+ texts,
+ embeddings,
+ embedding,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_existing_index(
+ cls: Type[PGVector],
+ embedding: Embeddings,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ pre_delete_collection: bool = False,
+ **kwargs: Any,
+ ) -> PGVector:
+ """
+ Get an instance of an existing PGVector store. This method will
+ return the instance of the store without inserting any new
+ embeddings.
+ """
+
+ connection_string = cls.get_connection_string(kwargs)
+
+ store = cls(
+ connection_string=connection_string,
+ collection_name=collection_name,
+ embedding_function=embedding,
+ distance_strategy=distance_strategy,
+ pre_delete_collection=pre_delete_collection,
+ )
+
+ return store
+
+ @classmethod
+ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
+ connection_string: str = get_from_dict_or_env(
+ data=kwargs,
+ key="connection_string",
+ env_key="PGVECTOR_CONNECTION_STRING",
+ )
+
+ if not connection_string:
+ raise ValueError(
+ "Postgres connection string is required. "
+ "Either pass it as a parameter "
+ "or set the PGVECTOR_CONNECTION_STRING environment variable."
+ )
+
+ return connection_string
+
+ @classmethod
+ def from_documents(
+ cls: Type[PGVector],
+ documents: List[Document],
+ embedding: Embeddings,
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+ ids: Optional[List[str]] = None,
+ pre_delete_collection: bool = False,
+ *,
+ use_jsonb: bool = False,
+ **kwargs: Any,
+ ) -> PGVector:
+ """
+ Return VectorStore initialized from documents and embeddings.
+ Postgres connection string is required: either pass it as a parameter,
+ or set the PGVECTOR_CONNECTION_STRING environment variable.
+ """
+
+ texts = [d.page_content for d in documents]
+ metadatas = [d.metadata for d in documents]
+ connection_string = cls.get_connection_string(kwargs)
+
+ kwargs["connection_string"] = connection_string
+
+ return cls.from_texts(
+ texts=texts,
+ pre_delete_collection=pre_delete_collection,
+ embedding=embedding,
+ distance_strategy=distance_strategy,
+ metadatas=metadatas,
+ ids=ids,
+ collection_name=collection_name,
+ use_jsonb=use_jsonb,
+ **kwargs,
+ )
+
+ @classmethod
+ def connection_string_from_db_params(
+ cls,
+ driver: str,
+ host: str,
+ port: int,
+ database: str,
+ user: str,
+ password: str,
+ ) -> str:
+ """Return connection string from database parameters."""
+ return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"
+
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
+ """
+ The 'correct' relevance function
+ may differ depending on a few things, including:
+ - the distance / similarity metric used by the VectorStore
+ - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+ - embedding dimensionality
+ - etc.
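+
+ For example, with COSINE distance the inherited default maps
+ ``relevance = 1.0 - distance`` (per the base VectorStore helper
+ ``_cosine_relevance_score_fn``).
+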
+ """ + if self.override_relevance_score_fn is not None: + return self.override_relevance_score_fn + + # Default strategy is to rely on distance strategy provided + # in vectorstore constructor + if self._distance_strategy == DistanceStrategy.COSINE: + return self._cosine_relevance_score_fn + elif self._distance_strategy == DistanceStrategy.EUCLIDEAN: + return self._euclidean_relevance_score_fn + elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + return self._max_inner_product_relevance_score_fn + else: + raise ValueError( + "No supported normalization function" + f" for distance_strategy of {self._distance_strategy}." + "Consider providing relevance_score_fn to PGVector constructor." + ) + + def max_marginal_relevance_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs selected using the maximal marginal relevance with score + to embedding vector. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Tuple[Document, float]]: List of Documents selected by maximal marginal + relevance to the query and score for each. + """ + results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) + + embedding_list = [result.EmbeddingStore.embedding for result in results] + + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + embedding_list, + k=k, + lambda_mult=lambda_mult, + ) + + candidates = self._results_to_docs_and_scores(results) + + return [r for i, r in enumerate(candidates) if i in mmr_selected] + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Document]: List of Documents selected by maximal marginal relevance. 
+ """ + embedding = self.embedding_function.embed_query(query) + return self.max_marginal_relevance_search_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + + def max_marginal_relevance_search_with_score( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs selected using the maximal marginal relevance with score. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Tuple[Document, float]]: List of Documents selected by maximal marginal + relevance to the query and score for each. + """ + embedding = self.embedding_function.embed_query(query) + docs = self.max_marginal_relevance_search_with_score_by_vector( + embedding=embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + return docs + + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance + to embedding vector. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Document]: List of Documents selected by maximal marginal relevance. + """ + docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + + return _results_to_docs(docs_and_scores) + + async def amax_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance.""" + + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. 
+ return await run_in_executor( + None, + self.max_marginal_relevance_search_by_vector, + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) diff --git a/libs/partners/postgres/poetry.lock b/libs/partners/postgres/poetry.lock index 6c4ff0070d9..8508a0b4aef 100644 --- a/libs/partners/postgres/poetry.lock +++ b/libs/partners/postgres/poetry.lock @@ -163,6 +163,77 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "greenlet" +version = "3.0.3" +description = "Lightweight in-process concurrent programming" +optional = false +python-versions = ">=3.7" +files = [ + {file = "greenlet-3.0.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9da2bd29ed9e4f15955dd1595ad7bc9320308a3b766ef7f837e23ad4b4aac31a"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d353cadd6083fdb056bb46ed07e4340b0869c305c8ca54ef9da3421acbdf6881"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dca1e2f3ca00b84a396bc1bce13dd21f680f035314d2379c4160c98153b2059b"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ed7fb269f15dc662787f4119ec300ad0702fa1b19d2135a37c2c4de6fadfd4a"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd4f49ae60e10adbc94b45c0b5e6a179acc1736cf7a90160b404076ee283cf83"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:73a411ef564e0e097dbe7e866bb2dda0f027e072b04da387282b02c308807405"}, + {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7f362975f2d179f9e26928c5b517524e89dd48530a0202570d55ad6ca5d8a56f"}, + {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:649dde7de1a5eceb258f9cb00bdf50e978c9db1b996964cd80703614c86495eb"}, + {file = "greenlet-3.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:68834da854554926fbedd38c76e60c4a2e3198c6fbed520b106a8986445caaf9"}, + {file = "greenlet-3.0.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:b1b5667cced97081bf57b8fa1d6bfca67814b0afd38208d52538316e9422fc61"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52f59dd9c96ad2fc0d5724107444f76eb20aaccb675bf825df6435acb7703559"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:afaff6cf5200befd5cec055b07d1c0a5a06c040fe5ad148abcd11ba6ab9b114e"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe754d231288e1e64323cfad462fcee8f0288654c10bdf4f603a39ed923bef33"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2797aa5aedac23af156bbb5a6aa2cd3427ada2972c828244eb7d1b9255846379"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7f009caad047246ed379e1c4dbcb8b020f0a390667ea74d2387be2998f58a22"}, + {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c5e1536de2aad7bf62e27baf79225d0d64360d4168cf2e6becb91baf1ed074f3"}, + {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:894393ce10ceac937e56ec00bb71c4c2f8209ad516e96033e4b3b1de270e200d"}, + {file = "greenlet-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:1ea188d4f49089fc6fb283845ab18a2518d279c7cd9da1065d7a84e991748728"}, + {file = 
"greenlet-3.0.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:70fb482fdf2c707765ab5f0b6655e9cfcf3780d8d87355a063547b41177599be"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4d1ac74f5c0c0524e4a24335350edad7e5f03b9532da7ea4d3c54d527784f2e"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:149e94a2dd82d19838fe4b2259f1b6b9957d5ba1b25640d2380bea9c5df37676"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15d79dd26056573940fcb8c7413d84118086f2ec1a8acdfa854631084393efcc"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b7db1ebff4ba09aaaeae6aa491daeb226c8150fc20e836ad00041bcb11230"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fcd2469d6a2cf298f198f0487e0a5b1a47a42ca0fa4dfd1b6862c999f018ebbf"}, + {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1f672519db1796ca0d8753f9e78ec02355e862d0998193038c7073045899f305"}, + {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2516a9957eed41dd8f1ec0c604f1cdc86758b587d964668b5b196a9db5bfcde6"}, + {file = "greenlet-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:bba5387a6975598857d86de9eac14210a49d554a77eb8261cc68b7d082f78ce2"}, + {file = "greenlet-3.0.3-cp37-cp37m-macosx_11_0_universal2.whl", hash = "sha256:5b51e85cb5ceda94e79d019ed36b35386e8c37d22f07d6a751cb659b180d5274"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:daf3cb43b7cf2ba96d614252ce1684c1bccee6b2183a01328c98d36fcd7d5cb0"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99bf650dc5d69546e076f413a87481ee1d2d09aaaaaca058c9251b6d8c14783f"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dd6e660effd852586b6a8478a1d244b8dc90ab5b1321751d2ea15deb49ed414"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3391d1e16e2a5a1507d83e4a8b100f4ee626e8eca43cf2cadb543de69827c4c"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e1f145462f1fa6e4a4ae3c0f782e580ce44d57c8f2c7aae1b6fa88c0b2efdb41"}, + {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1a7191e42732df52cb5f39d3527217e7ab73cae2cb3694d241e18f53d84ea9a7"}, + {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0448abc479fab28b00cb472d278828b3ccca164531daab4e970a0458786055d6"}, + {file = "greenlet-3.0.3-cp37-cp37m-win32.whl", hash = "sha256:b542be2440edc2d48547b5923c408cbe0fc94afb9f18741faa6ae970dbcb9b6d"}, + {file = "greenlet-3.0.3-cp37-cp37m-win_amd64.whl", hash = "sha256:01bc7ea167cf943b4c802068e178bbf70ae2e8c080467070d01bfa02f337ee67"}, + {file = "greenlet-3.0.3-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:1996cb9306c8595335bb157d133daf5cf9f693ef413e7673cb07e3e5871379ca"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ddc0f794e6ad661e321caa8d2f0a55ce01213c74722587256fb6566049a8b04"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9db1c18f0eaad2f804728c67d6c610778456e3e1cc4ab4bbd5eeb8e6053c6fc"}, + {file = 
"greenlet-3.0.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7170375bcc99f1a2fbd9c306f5be8764eaf3ac6b5cb968862cad4c7057756506"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b66c9c1e7ccabad3a7d037b2bcb740122a7b17a53734b7d72a344ce39882a1b"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:098d86f528c855ead3479afe84b49242e174ed262456c342d70fc7f972bc13c4"}, + {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:81bb9c6d52e8321f09c3d165b2a78c680506d9af285bfccbad9fb7ad5a5da3e5"}, + {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fd096eb7ffef17c456cfa587523c5f92321ae02427ff955bebe9e3c63bc9f0da"}, + {file = "greenlet-3.0.3-cp38-cp38-win32.whl", hash = "sha256:d46677c85c5ba00a9cb6f7a00b2bfa6f812192d2c9f7d9c4f6a55b60216712f3"}, + {file = "greenlet-3.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:419b386f84949bf0e7c73e6032e3457b82a787c1ab4a0e43732898a761cc9dbf"}, + {file = "greenlet-3.0.3-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:da70d4d51c8b306bb7a031d5cff6cc25ad253affe89b70352af5f1cb68e74b53"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:086152f8fbc5955df88382e8a75984e2bb1c892ad2e3c80a2508954e52295257"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d73a9fe764d77f87f8ec26a0c85144d6a951a6c438dfe50487df5595c6373eac"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7dcbe92cc99f08c8dd11f930de4d99ef756c3591a5377d1d9cd7dd5e896da71"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1551a8195c0d4a68fac7a4325efac0d541b48def35feb49d803674ac32582f61"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:64d7675ad83578e3fc149b617a444fab8efdafc9385471f868eb5ff83e446b8b"}, + {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b37eef18ea55f2ffd8f00ff8fe7c8d3818abd3e25fb73fae2ca3b672e333a7a6"}, + {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:77457465d89b8263bca14759d7c1684df840b6811b2499838cc5b040a8b5b113"}, + {file = "greenlet-3.0.3-cp39-cp39-win32.whl", hash = "sha256:57e8974f23e47dac22b83436bdcf23080ade568ce77df33159e019d161ce1d1e"}, + {file = "greenlet-3.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:c5ee858cfe08f34712f548c3c363e807e7186f03ad7a5039ebadb29e8c6be067"}, + {file = "greenlet-3.0.3.tar.gz", hash = "sha256:43374442353259554ce33599da8b692d5aa96f8976d567d4badf263371fbe491"}, +] + +[package.extras] +docs = ["Sphinx", "furo"] +test = ["objgraph", "psutil"] + [[package]] name = "idna" version = "3.6" @@ -250,13 +321,13 @@ langchain-core = ">=0.1.38,<0.2.0" [[package]] name = "langsmith" -version = "0.1.38" +version = "0.1.40" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.38-py3-none-any.whl", hash = "sha256:f36479f82cf537cf40d129ac2e485e72a3981360c7b6cf2549dad77d98eafd8f"}, - {file = "langsmith-0.1.38.tar.gz", hash = "sha256:2c1f98ac0a8c02e43b625650a6e13c65b09523551bfc21a59d20963f46f7d265"}, + {file = "langsmith-0.1.40-py3-none-any.whl", hash = "sha256:aa47d0f5a1eabd5c05ac6ce2cd3e28ccfc554d366e856a27b7c3c17c443881cb"}, + {file = "langsmith-0.1.40.tar.gz", hash = "sha256:50fdf313741cf94e978de06025fd180b56acf1d1a4549b0fd5453ef23d5461ef"}, ] [package.dependencies] @@ -322,6 +393,51 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "numpy" +version = "1.26.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, +] + [[package]] name = "orjson" version = "3.10.0" @@ -393,6 +509,19 @@ files = [ {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, ] +[[package]] +name = "pgvector" +version = "0.2.5" +description = "pgvector support for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pgvector-0.2.5-py2.py3-none-any.whl", hash = "sha256:5e5e93ec4d3c45ab1fa388729d56c602f6966296e19deee8878928c6d567e41b"}, +] + +[package.dependencies] +numpy = "*" + [[package]] name = "pluggy" version = "1.4.0" @@ -701,6 +830,93 @@ files = [ {file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"}, ] +[[package]] +name = "sqlalchemy" +version = "2.0.29" +description = "Database Abstraction Library" +optional = 
false +python-versions = ">=3.7" +files = [ + {file = "SQLAlchemy-2.0.29-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4c142852ae192e9fe5aad5c350ea6befe9db14370b34047e1f0f7cf99e63c63b"}, + {file = "SQLAlchemy-2.0.29-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:99a1e69d4e26f71e750e9ad6fdc8614fbddb67cfe2173a3628a2566034e223c7"}, + {file = "SQLAlchemy-2.0.29-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ef3fbccb4058355053c51b82fd3501a6e13dd808c8d8cd2561e610c5456013c"}, + {file = "SQLAlchemy-2.0.29-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d6753305936eddc8ed190e006b7bb33a8f50b9854823485eed3a886857ab8d1"}, + {file = "SQLAlchemy-2.0.29-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0f3ca96af060a5250a8ad5a63699180bc780c2edf8abf96c58af175921df847a"}, + {file = "SQLAlchemy-2.0.29-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c4520047006b1d3f0d89e0532978c0688219857eb2fee7c48052560ae76aca1e"}, + {file = "SQLAlchemy-2.0.29-cp310-cp310-win32.whl", hash = "sha256:b2a0e3cf0caac2085ff172c3faacd1e00c376e6884b5bc4dd5b6b84623e29e4f"}, + {file = "SQLAlchemy-2.0.29-cp310-cp310-win_amd64.whl", hash = "sha256:01d10638a37460616708062a40c7b55f73e4d35eaa146781c683e0fa7f6c43fb"}, + {file = "SQLAlchemy-2.0.29-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:308ef9cb41d099099fffc9d35781638986870b29f744382904bf9c7dadd08513"}, + {file = "SQLAlchemy-2.0.29-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:296195df68326a48385e7a96e877bc19aa210e485fa381c5246bc0234c36c78e"}, + {file = "SQLAlchemy-2.0.29-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a13b917b4ffe5a0a31b83d051d60477819ddf18276852ea68037a144a506efb9"}, + {file = "SQLAlchemy-2.0.29-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f6d971255d9ddbd3189e2e79d743ff4845c07f0633adfd1de3f63d930dbe673"}, + {file = "SQLAlchemy-2.0.29-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:61405ea2d563407d316c63a7b5271ae5d274a2a9fbcd01b0aa5503635699fa1e"}, + {file = "SQLAlchemy-2.0.29-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:de7202ffe4d4a8c1e3cde1c03e01c1a3772c92858837e8f3879b497158e4cb44"}, + {file = "SQLAlchemy-2.0.29-cp311-cp311-win32.whl", hash = "sha256:b5d7ed79df55a731749ce65ec20d666d82b185fa4898430b17cb90c892741520"}, + {file = "SQLAlchemy-2.0.29-cp311-cp311-win_amd64.whl", hash = "sha256:205f5a2b39d7c380cbc3b5dcc8f2762fb5bcb716838e2d26ccbc54330775b003"}, + {file = "SQLAlchemy-2.0.29-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d96710d834a6fb31e21381c6d7b76ec729bd08c75a25a5184b1089141356171f"}, + {file = "SQLAlchemy-2.0.29-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:52de4736404e53c5c6a91ef2698c01e52333988ebdc218f14c833237a0804f1b"}, + {file = "SQLAlchemy-2.0.29-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c7b02525ede2a164c5fa5014915ba3591730f2cc831f5be9ff3b7fd3e30958e"}, + {file = "SQLAlchemy-2.0.29-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dfefdb3e54cd15f5d56fd5ae32f1da2d95d78319c1f6dfb9bcd0eb15d603d5d"}, + {file = "SQLAlchemy-2.0.29-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a88913000da9205b13f6f195f0813b6ffd8a0c0c2bd58d499e00a30eb508870c"}, + {file = "SQLAlchemy-2.0.29-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fecd5089c4be1bcc37c35e9aa678938d2888845a134dd016de457b942cf5a758"}, + {file = "SQLAlchemy-2.0.29-cp312-cp312-win32.whl", hash = 
"sha256:8197d6f7a3d2b468861ebb4c9f998b9df9e358d6e1cf9c2a01061cb9b6cf4e41"}, + {file = "SQLAlchemy-2.0.29-cp312-cp312-win_amd64.whl", hash = "sha256:9b19836ccca0d321e237560e475fd99c3d8655d03da80c845c4da20dda31b6e1"}, + {file = "SQLAlchemy-2.0.29-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:87a1d53a5382cdbbf4b7619f107cc862c1b0a4feb29000922db72e5a66a5ffc0"}, + {file = "SQLAlchemy-2.0.29-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a0732dffe32333211801b28339d2a0babc1971bc90a983e3035e7b0d6f06b93"}, + {file = "SQLAlchemy-2.0.29-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90453597a753322d6aa770c5935887ab1fc49cc4c4fdd436901308383d698b4b"}, + {file = "SQLAlchemy-2.0.29-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ea311d4ee9a8fa67f139c088ae9f905fcf0277d6cd75c310a21a88bf85e130f5"}, + {file = "SQLAlchemy-2.0.29-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:5f20cb0a63a3e0ec4e169aa8890e32b949c8145983afa13a708bc4b0a1f30e03"}, + {file = "SQLAlchemy-2.0.29-cp37-cp37m-win32.whl", hash = "sha256:e5bbe55e8552019c6463709b39634a5fc55e080d0827e2a3a11e18eb73f5cdbd"}, + {file = "SQLAlchemy-2.0.29-cp37-cp37m-win_amd64.whl", hash = "sha256:c2f9c762a2735600654c654bf48dad388b888f8ce387b095806480e6e4ff6907"}, + {file = "SQLAlchemy-2.0.29-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7e614d7a25a43a9f54fcce4675c12761b248547f3d41b195e8010ca7297c369c"}, + {file = "SQLAlchemy-2.0.29-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:471fcb39c6adf37f820350c28aac4a7df9d3940c6548b624a642852e727ea586"}, + {file = "SQLAlchemy-2.0.29-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:988569c8732f54ad3234cf9c561364221a9e943b78dc7a4aaf35ccc2265f1930"}, + {file = "SQLAlchemy-2.0.29-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dddaae9b81c88083e6437de95c41e86823d150f4ee94bf24e158a4526cbead01"}, + {file = "SQLAlchemy-2.0.29-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:334184d1ab8f4c87f9652b048af3f7abea1c809dfe526fb0435348a6fef3d380"}, + {file = "SQLAlchemy-2.0.29-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:38b624e5cf02a69b113c8047cf7f66b5dfe4a2ca07ff8b8716da4f1b3ae81567"}, + {file = "SQLAlchemy-2.0.29-cp38-cp38-win32.whl", hash = "sha256:bab41acf151cd68bc2b466deae5deeb9e8ae9c50ad113444151ad965d5bf685b"}, + {file = "SQLAlchemy-2.0.29-cp38-cp38-win_amd64.whl", hash = "sha256:52c8011088305476691b8750c60e03b87910a123cfd9ad48576d6414b6ec2a1d"}, + {file = "SQLAlchemy-2.0.29-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3071ad498896907a5ef756206b9dc750f8e57352113c19272bdfdc429c7bd7de"}, + {file = "SQLAlchemy-2.0.29-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dba622396a3170974f81bad49aacebd243455ec3cc70615aeaef9e9613b5bca5"}, + {file = "SQLAlchemy-2.0.29-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b184e3de58009cc0bf32e20f137f1ec75a32470f5fede06c58f6c355ed42a72"}, + {file = "SQLAlchemy-2.0.29-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c37f1050feb91f3d6c32f864d8e114ff5545a4a7afe56778d76a9aec62638ba"}, + {file = "SQLAlchemy-2.0.29-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bda7ce59b06d0f09afe22c56714c65c957b1068dee3d5e74d743edec7daba552"}, + {file = "SQLAlchemy-2.0.29-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:25664e18bef6dc45015b08f99c63952a53a0a61f61f2e48a9e70cec27e55f699"}, + {file = "SQLAlchemy-2.0.29-cp39-cp39-win32.whl", hash = 
"sha256:77d29cb6c34b14af8a484e831ab530c0f7188f8efed1c6a833a2c674bf3c26ec"}, + {file = "SQLAlchemy-2.0.29-cp39-cp39-win_amd64.whl", hash = "sha256:04c487305ab035a9548f573763915189fc0fe0824d9ba28433196f8436f1449c"}, + {file = "SQLAlchemy-2.0.29-py3-none-any.whl", hash = "sha256:dc4ee2d4ee43251905f88637d5281a8d52e916a021384ec10758826f5cbae305"}, + {file = "SQLAlchemy-2.0.29.tar.gz", hash = "sha256:bd9566b8e58cabd700bc367b60e90d9349cd16f0984973f98a9a09f9c64e86f0"}, +] + +[package.dependencies] +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} +typing-extensions = ">=4.6.0" + +[package.extras] +aiomysql = ["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"] +aioodbc = ["aioodbc", "greenlet (!=0.4.17)"] +aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing_extensions (!=3.10.0.1)"] +asyncio = ["greenlet (!=0.4.17)"] +asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"] +mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"] +mssql = ["pyodbc"] +mssql-pymssql = ["pymssql"] +mssql-pyodbc = ["pyodbc"] +mypy = ["mypy (>=0.910)"] +mysql = ["mysqlclient (>=1.4.0)"] +mysql-connector = ["mysql-connector-python"] +oracle = ["cx_oracle (>=8)"] +oracle-oracledb = ["oracledb (>=1.0.1)"] +postgresql = ["psycopg2 (>=2.7)"] +postgresql-asyncpg = ["asyncpg", "greenlet (!=0.4.17)"] +postgresql-pg8000 = ["pg8000 (>=1.29.1)"] +postgresql-psycopg = ["psycopg (>=3.0.7)"] +postgresql-psycopg2binary = ["psycopg2-binary"] +postgresql-psycopg2cffi = ["psycopg2cffi"] +postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] +pymysql = ["pymysql"] +sqlcipher = ["sqlcipher3_binary"] + [[package]] name = "tenacity" version = "8.2.3" @@ -728,13 +944,13 @@ files = [ [[package]] name = "typing-extensions" -version = "4.10.0" +version = "4.11.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, - {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, + {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"}, + {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"}, ] [[package]] @@ -768,4 +984,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "ee9808589dabaecefbb3b06d09e0c7a172116173ca9ea0de28263396793f377a" +content-hash = "02a20cf8f1209824252361c78bffcdfa960bf92ef3214807cc9f494eb533b7e4" diff --git a/libs/partners/postgres/pyproject.toml b/libs/partners/postgres/pyproject.toml index c9a598d5ca6..25be0deb997 100644 --- a/libs/partners/postgres/pyproject.toml +++ b/libs/partners/postgres/pyproject.toml @@ -16,6 +16,9 @@ langchain-core = "^0.1" psycopg = "^3.1.18" langgraph = "^0.0.32" psycopg-pool = "^3.2.1" +sqlalchemy = "^2.0.29" +pgvector = "^0.2.5" +numpy = "^1.26.4" [tool.poetry.group.test] optional = true diff --git a/libs/partners/postgres/tests/integration_tests/fake_embeddings.py b/libs/partners/postgres/tests/integration_tests/fake_embeddings.py new file mode 100644 index 
00000000000..81fd2aa5ae6
--- /dev/null
+++ b/libs/partners/postgres/tests/integration_tests/fake_embeddings.py
@@ -0,0 +1,28 @@
+"""Copied from community."""
+from typing import List
+
+from langchain_core.embeddings import Embeddings
+
+fake_texts = ["foo", "bar", "baz"]
+
+
+class FakeEmbeddings(Embeddings):
+    """Fake embeddings functionality for testing."""
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Return simple embeddings.
+        Embeddings encode each text as its index."""
+        return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        return self.embed_documents(texts)
+
+    def embed_query(self, text: str) -> List[float]:
+        """Return constant query embeddings.
+        Embeddings are identical to embed_documents(texts)[0].
+        Distance to each text will be that text's index,
+        as it was passed to embed_documents."""
+        return [float(1.0)] * 9 + [float(0.0)]
+
+    async def aembed_query(self, text: str) -> List[float]:
+        return self.embed_query(text)
diff --git a/libs/partners/postgres/tests/integration_tests/fixtures/__init__.py b/libs/partners/postgres/tests/integration_tests/fixtures/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/libs/partners/postgres/tests/integration_tests/fixtures/filtering_test_cases.py b/libs/partners/postgres/tests/integration_tests/fixtures/filtering_test_cases.py
new file mode 100644
index 00000000000..9dcca44f563
--- /dev/null
+++ b/libs/partners/postgres/tests/integration_tests/fixtures/filtering_test_cases.py
@@ -0,0 +1,218 @@
+"""Module needs to move to a standalone package."""
+from langchain_core.documents import Document
+
+metadatas = [
+    {
+        "name": "adam",
+        "date": "2021-01-01",
+        "count": 1,
+        "is_active": True,
+        "tags": ["a", "b"],
+        "location": [1.0, 2.0],
+        "id": 1,
+        "height": 10.0,  # Float column
+        "happiness": 0.9,  # Float column
+        "sadness": 0.1,  # Float column
+    },
+    {
+        "name": "bob",
+        "date": "2021-01-02",
+        "count": 2,
+        "is_active": False,
+        "tags": ["b", "c"],
+        "location": [2.0, 3.0],
+        "id": 2,
+        "height": 5.7,  # Float column
+        "happiness": 0.8,  # Float column
+        "sadness": 0.1,  # Float column
+    },
+    {
+        "name": "jane",
+        "date": "2021-01-01",
+        "count": 3,
+        "is_active": True,
+        "tags": ["b", "d"],
+        "location": [3.0, 4.0],
+        "id": 3,
+        "height": 2.4,  # Float column
+        "happiness": None,
+        # Sadness missing intentionally
+    },
+]
+texts = ["id {id}".format(id=metadata["id"]) for metadata in metadatas]
+
+DOCUMENTS = [
+    Document(page_content=text, metadata=metadata)
+    for text, metadata in zip(texts, metadatas)
+]
+
+
+TYPE_1_FILTERING_TEST_CASES = [
+    # These tests only involve equality checks
+    (
+        {"id": 1},
+        [1],
+    ),
+    # String field
+    (
+        # check name
+        {"name": "adam"},
+        [1],
+    ),
+    # Boolean fields
+    (
+        {"is_active": True},
+        [1, 3],
+    ),
+    (
+        {"is_active": False},
+        [2],
+    ),
+    # And semantics for top level filtering
+    (
+        {"id": 1, "is_active": True},
+        [1],
+    ),
+    (
+        {"id": 1, "is_active": False},
+        [],
+    ),
+]
+
+TYPE_2_FILTERING_TEST_CASES = [
+    # These involve equality checks and other operators
+    # like $ne, $gt, $gte, $lt, $lte, $not
+    (
+        {"id": 1},
+        [1],
+    ),
+    (
+        {"id": {"$ne": 1}},
+        [2, 3],
+    ),
+    (
+        {"id": {"$gt": 1}},
+        [2, 3],
+    ),
+    (
+        {"id": {"$gte": 1}},
+        [1, 2, 3],
+    ),
+    (
+        {"id": {"$lt": 1}},
+        [],
+    ),
+    (
+        {"id": {"$lte": 1}},
+        [1],
+    ),
+    # Repeat all the same tests with name (string column)
+    (
+        {"name": "adam"},
+        [1],
+    ),
+    (
+        {"name": "bob"},
+        [2],
+    ),
+    (
+        {"name": {"$eq": "adam"}},
+        [1],
+    ),
+    (
+        {"name": {"$ne": "adam"}},
+        [2, 3],
+    ),
+    # And also gt, gte, lt, lte relying on lexicographical ordering
+    (
+        {"name": {"$gt": "jane"}},
+        [],
+    ),
+    (
+        {"name": {"$gte": "jane"}},
+        [3],
+    ),
+    (
+        {"name": {"$lt": "jane"}},
+        [1, 2],
+    ),
+    (
+        {"name": {"$lte": "jane"}},
+        [1, 2, 3],
+    ),
+    (
+        {"is_active": {"$eq": True}},
+        [1, 3],
+    ),
+    (
+        {"is_active": {"$ne": True}},
+        [2],
+    ),
+    # Test float column.
+    (
+        {"height": {"$gt": 5.0}},
+        [1, 2],
+    ),
+    (
+        {"height": {"$gte": 5.0}},
+        [1, 2],
+    ),
+    (
+        {"height": {"$lt": 5.0}},
+        [3],
+    ),
+    (
+        {"height": {"$lte": 5.8}},
+        [2, 3],
+    ),
+]
+
+TYPE_3_FILTERING_TEST_CASES = [
+    # These involve usage of AND and OR operators
+    (
+        {"$or": [{"id": 1}, {"id": 2}]},
+        [1, 2],
+    ),
+    (
+        {"$or": [{"id": 1}, {"name": "bob"}]},
+        [1, 2],
+    ),
+    (
+        {"$and": [{"id": 1}, {"id": 2}]},
+        [],
+    ),
+    (
+        {"$or": [{"id": 1}, {"id": 2}, {"id": 3}]},
+        [1, 2, 3],
+    ),
+]
+
+TYPE_4_FILTERING_TEST_CASES = [
+    # These involve special operators like $in, $nin, $between
+    # Test between
+    (
+        {"id": {"$between": (1, 2)}},
+        [1, 2],
+    ),
+    (
+        {"id": {"$between": (1, 1)}},
+        [1],
+    ),
+    (
+        {"name": {"$in": ["adam", "bob"]}},
+        [1, 2],
+    ),
+]
+
+TYPE_5_FILTERING_TEST_CASES = [
+    # These involve special operators like $like, $ilike that
+    # may be specific to certain databases.
+    (
+        {"name": {"$like": "a%"}},
+        [1],
+    ),
+    (
+        {"name": {"$like": "%a%"}},  # adam and jane
+        [1, 3],
+    ),
+]
diff --git a/libs/partners/postgres/tests/integration_tests/test_vectorstore.py b/libs/partners/postgres/tests/integration_tests/test_vectorstore.py
new file mode 100644
index 00000000000..2a89103d356
--- /dev/null
+++ b/libs/partners/postgres/tests/integration_tests/test_vectorstore.py
@@ -0,0 +1,505 @@
+"""Test PGVector functionality."""
+
+import os
+from typing import Any, Dict, Generator, List
+
+import pytest
+import sqlalchemy
+from langchain_core.documents import Document
+from sqlalchemy.orm import Session
+
+from langchain_postgres.vectorstores import (
+    SUPPORTED_OPERATORS,
+    PGVector,
+)
+from tests.integration_tests.fake_embeddings import FakeEmbeddings
+from tests.integration_tests.fixtures.filtering_test_cases import (
+    DOCUMENTS,
+    TYPE_1_FILTERING_TEST_CASES,
+    TYPE_2_FILTERING_TEST_CASES,
+    TYPE_3_FILTERING_TEST_CASES,
+    TYPE_4_FILTERING_TEST_CASES,
+    TYPE_5_FILTERING_TEST_CASES,
+)
+
+# The connection string matches the default settings in the docker-compose file
+# located in the root of the repository: [root]/docker/docker-compose.yml
+# Non-standard ports are used to avoid conflicts with other local postgres
+# instances.
+# To spin up postgres with the pgvector extension:
+# cd [root]/docker
+# docker compose up pgvector
+CONNECTION_STRING = PGVector.connection_string_from_db_params(
+    driver=os.environ.get("TEST_PGVECTOR_DRIVER", "psycopg"),
+    host=os.environ.get("TEST_PGVECTOR_HOST", "localhost"),
+    port=int(os.environ.get("TEST_PGVECTOR_PORT", "6024")),
+    database=os.environ.get("TEST_PGVECTOR_DATABASE", "langchain"),
+    user=os.environ.get("TEST_PGVECTOR_USER", "langchain"),
+    password=os.environ.get("TEST_PGVECTOR_PASSWORD", "langchain"),
+)
+
+ADA_TOKEN_COUNT = 1536
+
+
+class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
+    """Fake embeddings functionality for testing."""
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Return simple embeddings."""
+        return [
+            [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
+        ]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Return simple embeddings."""
+        return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]
+
+
+def test_pgvector(pgvector: PGVector) -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search("foo", k=1)
+    assert output == [Document(page_content="foo")]
+
+
+def test_pgvector_embeddings() -> None:
+    """Test end to end construction with embeddings and search."""
+    texts = ["foo", "bar", "baz"]
+    text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts)
+    text_embedding_pairs = list(zip(texts, text_embeddings))
+    docsearch = PGVector.from_embeddings(
+        text_embeddings=text_embedding_pairs,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search("foo", k=1)
+    assert output == [Document(page_content="foo")]
+
+
+def test_pgvector_with_metadatas() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search("foo", k=1)
+    assert output == [Document(page_content="foo", metadata={"page": "0"})]
+
+
+def test_pgvector_with_metadatas_with_scores() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1)
+    assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
+
+
+def test_pgvector_with_filter_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
+    assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
+
+
+def test_pgvector_with_filter_distant_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"})
+    assert output == [
+        (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406)
+    ]
+
+
+def test_pgvector_with_filter_no_match() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"})
+    assert output == []
+
+
+def test_pgvector_collection_with_metadata() -> None:
+    """Test end to end collection construction."""
+    pgvector = PGVector(
+        collection_name="test_collection",
+        collection_metadata={"foo": "bar"},
+        embedding_function=FakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    session = Session(pgvector._create_engine())
+    collection = pgvector.get_collection(session)
+    if collection is None:
+        assert False, "Expected a CollectionStore object but received None"
+    else:
+        assert collection.name == "test_collection"
+        assert collection.cmetadata == {"foo": "bar"}
+
+
+def test_pgvector_with_filter_in_set() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score(
+        "foo", k=2, filter={"page": {"IN": ["0", "2"]}}
+    )
+    assert output == [
+        (Document(page_content="foo", metadata={"page": "0"}), 0.0),
+        (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406),
+    ]
+
+
+def test_pgvector_with_filter_nin_set() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.similarity_search_with_score(
+        "foo", k=2, filter={"page": {"NIN": ["1"]}}
+    )
+    assert output == [
+        (Document(page_content="foo", metadata={"page": "0"}), 0.0),
+        (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406),
+    ]
+
+
+def test_pgvector_delete_docs() -> None:
+    """Add and delete documents."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        ids=["1", "2", "3"],
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    docsearch.delete(["1", "2"])
+    with docsearch._make_session() as session:
+        records = list(session.query(docsearch.EmbeddingStore).all())
+        # ignoring type error since mypy cannot determine whether
+        # the list is sortable
+        assert sorted(record.custom_id for record in records) == ["3"]  # type: ignore
+
+    docsearch.delete(["2", "3"])  # Should not raise on missing ids
+    with docsearch._make_session() as session:
+        records = list(session.query(docsearch.EmbeddingStore).all())
+        # ignoring type error since mypy cannot determine whether
+        # the list is sortable
+        assert sorted(record.custom_id for record in records) == []  # type: ignore
+
+
+def test_pgvector_relevance_score() -> None:
+    """Test to make sure the relevance score is scaled to 0-1."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+
+    output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
+    assert output == [
+        (Document(page_content="foo", metadata={"page": "0"}), 1.0),
+        (Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065),
+        (Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621),
+    ]
+
+
+def test_pgvector_retriever_search_threshold() -> None:
+    """Test using retriever for searching with threshold."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+
+    retriever = docsearch.as_retriever(
+        search_type="similarity_score_threshold",
+        search_kwargs={"k": 3, "score_threshold": 0.999},
+    )
+    output = retriever.get_relevant_documents("summer")
+    assert output == [
+        Document(page_content="foo", metadata={"page": "0"}),
+        Document(page_content="bar", metadata={"page": "1"}),
+    ]
+
+
+def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None:
+    """Test searching with threshold and custom normalization function."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+        relevance_score_fn=lambda d: d * 0,
+    )
+
+    retriever = docsearch.as_retriever(
+        search_type="similarity_score_threshold",
+        search_kwargs={"k": 3, "score_threshold": 0.5},
+    )
+    output = retriever.get_relevant_documents("foo")
+    assert output == []
+
+
+def test_pgvector_max_marginal_relevance_search() -> None:
+    """Test max marginal relevance search."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3)
+    assert output == [Document(page_content="foo")]
+
+
+def test_pgvector_max_marginal_relevance_search_with_score() -> None:
+    """Test max marginal relevance search with relevance scores."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
+    assert output == [(Document(page_content="foo"), 0.0)]
+
+
+def test_pgvector_with_custom_connection() -> None:
+    """Test construction using a custom connection."""
+    texts = ["foo", "bar", "baz"]
+    engine = sqlalchemy.create_engine(CONNECTION_STRING)
+    with engine.connect() as connection:
+        docsearch = PGVector.from_texts(
+            texts=texts,
+            collection_name="test_collection",
+            embedding=FakeEmbeddingsWithAdaDimension(),
+            connection_string=CONNECTION_STRING,
+            pre_delete_collection=True,
+            connection=connection,
+        )
+        output = docsearch.similarity_search("foo", k=1)
+        assert output == [Document(page_content="foo")]
+
+
+def test_pgvector_with_custom_engine_args() -> None:
+    """Test construction using custom engine arguments."""
+    texts = ["foo", "bar", "baz"]
+    engine_args = {
+        "pool_size": 5,
+        "max_overflow": 10,
+        "pool_recycle": -1,
+        "pool_use_lifo": False,
+        "pool_pre_ping": False,
+        "pool_timeout": 30,
+    }
+    docsearch = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+        engine_args=engine_args,
+    )
+    output = docsearch.similarity_search("foo", k=1)
+    assert output == [Document(page_content="foo")]
+
+
+# We should reuse this test case across other integrations.
+# Add database fixture using pytest
+@pytest.fixture
+def pgvector() -> Generator[PGVector, None, None]:
+    """Create a PGVector instance."""
+    store = PGVector.from_documents(
+        documents=DOCUMENTS,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+        relevance_score_fn=lambda d: d * 0,
+        use_jsonb=True,
+    )
+    try:
+        yield store
+        # Do clean up
+    finally:
+        store.drop_tables()
+
+
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES[:1])
+def test_pgvector_with_metadata_filters_1(
+    pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search."""
+    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
+def test_pgvector_with_metadata_filters_2(
+    pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search."""
+    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES)
+def test_pgvector_with_metadata_filters_3(
+    pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search."""
+    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES)
+def test_pgvector_with_metadata_filters_4(
+    pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search."""
+    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES)
+def test_pgvector_with_metadata_filters_5(
+    pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search."""
+    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
+@pytest.mark.parametrize(
+    "invalid_filter",
+    [
+        ["hello"],
+        {
+            "id": 2,
+            "$name": "foo",
+        },
+        {"$or": {}},
+        {"$and": {}},
+        {"$between": {}},
+        {"$eq": {}},
+    ],
+)
+def test_invalid_filters(pgvector: PGVector, invalid_filter: Any) -> None:
+    """Verify that invalid filters raise an error."""
+    with pytest.raises(ValueError):
+        pgvector._create_filter_clause(invalid_filter)
+
+
+def test_validate_operators() -> None:
+    """Verify that all operators have been categorized."""
+    assert sorted(SUPPORTED_OPERATORS) == [
+        "$and",
+        "$between",
+        "$eq",
+        "$gt",
+        "$gte",
+        "$ilike",
+        "$in",
+        "$like",
+        "$lt",
+        "$lte",
+        "$ne",
+        "$nin",
+        "$or",
+    ]
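A minimal end-to-end usage sketch of the vectorstore this patch adds, assuming the docker-compose postgres described in the test comments above is running on localhost:6024. It relies only on APIs exercised by the tests (PGVector.connection_string_from_db_params, PGVector.from_documents, and similarity_search with the $-prefixed filter operators checked by test_validate_operators); the collection name and documents are illustrative, and FakeEmbeddings from the test suite stands in for a real Embeddings implementation:

from langchain_core.documents import Document

from langchain_postgres.vectorstores import PGVector
from tests.integration_tests.fake_embeddings import FakeEmbeddings

# Same parameters as the integration tests above (non-standard port 6024).
connection_string = PGVector.connection_string_from_db_params(
    driver="psycopg",
    host="localhost",
    port=6024,
    database="langchain",
    user="langchain",
    password="langchain",
)

# Build a collection from documents. use_jsonb=True stores metadata in a
# JSONB column so the $-prefixed operators can filter server-side.
store = PGVector.from_documents(
    documents=[
        Document(page_content="id 1", metadata={"id": 1, "name": "adam"}),
        Document(page_content="id 2", metadata={"id": 2, "name": "bob"}),
    ],
    collection_name="demo_collection",  # illustrative name
    embedding=FakeEmbeddings(),
    connection_string=connection_string,
    pre_delete_collection=True,
    use_jsonb=True,
)

# Logical operators ($and, $or) take a list of sub-filters; comparison
# operators nest under the metadata key they apply to.
docs = store.similarity_search(
    "id 1",
    k=2,
    filter={"$or": [{"name": {"$eq": "adam"}}, {"id": {"$gt": 1}}]},
)

Note that the uppercase "IN"/"NIN" spellings in some tests above exercise the older filter syntax, while the $-prefixed set validated by test_validate_operators is the supported surface going forward.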