mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 02:33:19 +00:00
Update VectorStore Class Method Typing (#2731)
Avoid using placeholder methods that only perform a `cast()` operation because the typing would otherwise be inferred to be the parent `VectorStore` class. This is unnecessary with TypeVar's.
This commit is contained in:
parent
446c3d586c
commit
0806951c07
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -210,7 +210,7 @@ class AtlasDB(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls,
|
cls: Type[AtlasDB],
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
embedding: Optional[Embeddings] = None,
|
embedding: Optional[Embeddings] = None,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
@ -270,7 +270,7 @@ class AtlasDB(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_documents(
|
def from_documents(
|
||||||
cls,
|
cls: Type[AtlasDB],
|
||||||
documents: List[Document],
|
documents: List[Document],
|
||||||
embedding: Optional[Embeddings] = None,
|
embedding: Optional[Embeddings] = None,
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, Iterable, List, Optional
|
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, root_validator
|
from pydantic import BaseModel, Field, root_validator
|
||||||
|
|
||||||
@ -10,6 +10,8 @@ from langchain.docstore.document import Document
|
|||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.schema import BaseRetriever
|
from langchain.schema import BaseRetriever
|
||||||
|
|
||||||
|
VST = TypeVar("VST", bound="VectorStore")
|
||||||
|
|
||||||
|
|
||||||
class VectorStore(ABC):
|
class VectorStore(ABC):
|
||||||
"""Interface for vector stores."""
|
"""Interface for vector stores."""
|
||||||
@ -153,11 +155,11 @@ class VectorStore(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_documents(
|
def from_documents(
|
||||||
cls,
|
cls: Type[VST],
|
||||||
documents: List[Document],
|
documents: List[Document],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> VectorStore:
|
) -> VST:
|
||||||
"""Return VectorStore initialized from documents and embeddings."""
|
"""Return VectorStore initialized from documents and embeddings."""
|
||||||
texts = [d.page_content for d in documents]
|
texts = [d.page_content for d in documents]
|
||||||
metadatas = [d.metadata for d in documents]
|
metadatas = [d.metadata for d in documents]
|
||||||
@ -165,11 +167,11 @@ class VectorStore(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def afrom_documents(
|
async def afrom_documents(
|
||||||
cls,
|
cls: Type[VST],
|
||||||
documents: List[Document],
|
documents: List[Document],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> VectorStore:
|
) -> VST:
|
||||||
"""Return VectorStore initialized from documents and embeddings."""
|
"""Return VectorStore initialized from documents and embeddings."""
|
||||||
texts = [d.page_content for d in documents]
|
texts = [d.page_content for d in documents]
|
||||||
metadatas = [d.metadata for d in documents]
|
metadatas = [d.metadata for d in documents]
|
||||||
@ -178,22 +180,22 @@ class VectorStore(ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls,
|
cls: Type[VST],
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> VectorStore:
|
) -> VST:
|
||||||
"""Return VectorStore initialized from texts and embeddings."""
|
"""Return VectorStore initialized from texts and embeddings."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def afrom_texts(
|
async def afrom_texts(
|
||||||
cls,
|
cls: Type[VST],
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> VectorStore:
|
) -> VST:
|
||||||
"""Return VectorStore initialized from texts and embeddings."""
|
"""Return VectorStore initialized from texts and embeddings."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -269,7 +269,7 @@ class Chroma(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls,
|
cls: Type[Chroma],
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
embedding: Optional[Embeddings] = None,
|
embedding: Optional[Embeddings] = None,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
@ -307,7 +307,7 @@ class Chroma(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_documents(
|
def from_documents(
|
||||||
cls,
|
cls: Type[Chroma],
|
||||||
documents: List[Document],
|
documents: List[Document],
|
||||||
embedding: Optional[Embeddings] = None,
|
embedding: Optional[Embeddings] = None,
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
|
"""VectorStore wrapper around a Postgres/PGVector database."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from pgvector.sqlalchemy import Vector
|
from pgvector.sqlalchemy import Vector
|
||||||
@ -346,7 +349,7 @@ class PGVector(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls,
|
cls: Type[PGVector],
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
@ -355,7 +358,7 @@ class PGVector(VectorStore):
|
|||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> "PGVector":
|
) -> PGVector:
|
||||||
"""
|
"""
|
||||||
Return VectorStore initialized from texts and embeddings.
|
Return VectorStore initialized from texts and embeddings.
|
||||||
Postgres connection string is required
|
Postgres connection string is required
|
||||||
@ -395,7 +398,7 @@ class PGVector(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_documents(
|
def from_documents(
|
||||||
cls,
|
cls: Type[PGVector],
|
||||||
documents: List[Document],
|
documents: List[Document],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||||
@ -403,7 +406,7 @@ class PGVector(VectorStore):
|
|||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
pre_delete_collection: bool = False,
|
pre_delete_collection: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> "PGVector":
|
) -> PGVector:
|
||||||
"""
|
"""
|
||||||
Return VectorStore initialized from documents and embeddings.
|
Return VectorStore initialized from documents and embeddings.
|
||||||
Postgres connection string is required
|
Postgres connection string is required
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
"""Wrapper around Qdrant vector database."""
|
"""Wrapper around Qdrant vector database."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
@ -176,55 +178,9 @@ class Qdrant(VectorStore):
|
|||||||
for i in mmr_selected
|
for i in mmr_selected
|
||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_documents(
|
|
||||||
cls,
|
|
||||||
documents: List[Document],
|
|
||||||
embedding: Embeddings,
|
|
||||||
location: Optional[str] = None,
|
|
||||||
url: Optional[str] = None,
|
|
||||||
port: Optional[int] = 6333,
|
|
||||||
grpc_port: int = 6334,
|
|
||||||
prefer_grpc: bool = False,
|
|
||||||
https: Optional[bool] = None,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
prefix: Optional[str] = None,
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
host: Optional[str] = None,
|
|
||||||
path: Optional[str] = None,
|
|
||||||
collection_name: Optional[str] = None,
|
|
||||||
distance_func: str = "Cosine",
|
|
||||||
content_payload_key: str = CONTENT_KEY,
|
|
||||||
metadata_payload_key: str = METADATA_KEY,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> "Qdrant":
|
|
||||||
return cast(
|
|
||||||
Qdrant,
|
|
||||||
super().from_documents(
|
|
||||||
documents,
|
|
||||||
embedding,
|
|
||||||
location=location,
|
|
||||||
url=url,
|
|
||||||
port=port,
|
|
||||||
grpc_port=grpc_port,
|
|
||||||
prefer_grpc=prefer_grpc,
|
|
||||||
https=https,
|
|
||||||
api_key=api_key,
|
|
||||||
prefix=prefix,
|
|
||||||
timeout=timeout,
|
|
||||||
host=host,
|
|
||||||
path=path,
|
|
||||||
collection_name=collection_name,
|
|
||||||
distance_func=distance_func,
|
|
||||||
content_payload_key=content_payload_key,
|
|
||||||
metadata_payload_key=metadata_payload_key,
|
|
||||||
**kwargs,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls,
|
cls: Type[Qdrant],
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
@ -244,7 +200,7 @@ class Qdrant(VectorStore):
|
|||||||
content_payload_key: str = CONTENT_KEY,
|
content_payload_key: str = CONTENT_KEY,
|
||||||
metadata_payload_key: str = METADATA_KEY,
|
metadata_payload_key: str = METADATA_KEY,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> "Qdrant":
|
) -> Qdrant:
|
||||||
"""Construct Qdrant wrapper from raw documents.
|
"""Construct Qdrant wrapper from raw documents.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
@ -227,7 +227,7 @@ class Redis(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls,
|
cls: Type[Redis],
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Wrapper around weaviate vector database."""
|
"""Wrapper around weaviate vector database."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, Iterable, List, Optional
|
from typing import Any, Dict, Iterable, List, Optional, Type
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
@ -104,11 +104,11 @@ class Weaviate(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls,
|
cls: Type[Weaviate],
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> VectorStore:
|
) -> Weaviate:
|
||||||
"""Not implemented for Weaviate yet."""
|
"""Not implemented for Weaviate yet."""
|
||||||
raise NotImplementedError("weaviate does not currently support `from_texts`.")
|
raise NotImplementedError("weaviate does not currently support `from_texts`.")
|
||||||
|
Loading…
Reference in New Issue
Block a user