mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 17:53:37 +00:00
Update VectorStore Class Method Typing (#2731)
Avoid using placeholder methods that only perform a `cast()` operation because the typing would otherwise be inferred to be the parent `VectorStore` class. This is unnecessary with TypeVar's.
This commit is contained in:
parent
446c3d586c
commit
0806951c07
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Iterable, List, Optional
|
||||
from typing import Any, Iterable, List, Optional, Type
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -210,7 +210,7 @@ class AtlasDB(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
cls: Type[AtlasDB],
|
||||
texts: List[str],
|
||||
embedding: Optional[Embeddings] = None,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
@ -270,7 +270,7 @@ class AtlasDB(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
cls,
|
||||
cls: Type[AtlasDB],
|
||||
documents: List[Document],
|
||||
embedding: Optional[Embeddings] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
|
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
@ -10,6 +10,8 @@ from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
VST = TypeVar("VST", bound="VectorStore")
|
||||
|
||||
|
||||
class VectorStore(ABC):
|
||||
"""Interface for vector stores."""
|
||||
@ -153,11 +155,11 @@ class VectorStore(ABC):
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
cls,
|
||||
cls: Type[VST],
|
||||
documents: List[Document],
|
||||
embedding: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> VectorStore:
|
||||
) -> VST:
|
||||
"""Return VectorStore initialized from documents and embeddings."""
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
@ -165,11 +167,11 @@ class VectorStore(ABC):
|
||||
|
||||
@classmethod
|
||||
async def afrom_documents(
|
||||
cls,
|
||||
cls: Type[VST],
|
||||
documents: List[Document],
|
||||
embedding: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> VectorStore:
|
||||
) -> VST:
|
||||
"""Return VectorStore initialized from documents and embeddings."""
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
@ -178,22 +180,22 @@ class VectorStore(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
cls: Type[VST],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> VectorStore:
|
||||
) -> VST:
|
||||
"""Return VectorStore initialized from texts and embeddings."""
|
||||
|
||||
@classmethod
|
||||
async def afrom_texts(
|
||||
cls,
|
||||
cls: Type[VST],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> VectorStore:
|
||||
) -> VST:
|
||||
"""Return VectorStore initialized from texts and embeddings."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -269,7 +269,7 @@ class Chroma(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
cls: Type[Chroma],
|
||||
texts: List[str],
|
||||
embedding: Optional[Embeddings] = None,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
@ -307,7 +307,7 @@ class Chroma(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
cls,
|
||||
cls: Type[Chroma],
|
||||
documents: List[Document],
|
||||
embedding: Optional[Embeddings] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
|
@ -1,7 +1,10 @@
|
||||
"""VectorStore wrapper around a Postgres/PGVector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
import sqlalchemy
|
||||
from pgvector.sqlalchemy import Vector
|
||||
@ -346,7 +349,7 @@ class PGVector(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
cls: Type[PGVector],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
@ -355,7 +358,7 @@ class PGVector(VectorStore):
|
||||
ids: Optional[List[str]] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> "PGVector":
|
||||
) -> PGVector:
|
||||
"""
|
||||
Return VectorStore initialized from texts and embeddings.
|
||||
Postgres connection string is required
|
||||
@ -395,7 +398,7 @@ class PGVector(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
cls,
|
||||
cls: Type[PGVector],
|
||||
documents: List[Document],
|
||||
embedding: Embeddings,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
@ -403,7 +406,7 @@ class PGVector(VectorStore):
|
||||
ids: Optional[List[str]] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> "PGVector":
|
||||
) -> PGVector:
|
||||
"""
|
||||
Return VectorStore initialized from documents and embeddings.
|
||||
Postgres connection string is required
|
||||
|
@ -1,7 +1,9 @@
|
||||
"""Wrapper around Qdrant vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from operator import itemgetter
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@ -176,55 +178,9 @@ class Qdrant(VectorStore):
|
||||
for i in mmr_selected
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
cls,
|
||||
documents: List[Document],
|
||||
embedding: Embeddings,
|
||||
location: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
port: Optional[int] = 6333,
|
||||
grpc_port: int = 6334,
|
||||
prefer_grpc: bool = False,
|
||||
https: Optional[bool] = None,
|
||||
api_key: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
host: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
distance_func: str = "Cosine",
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
**kwargs: Any,
|
||||
) -> "Qdrant":
|
||||
return cast(
|
||||
Qdrant,
|
||||
super().from_documents(
|
||||
documents,
|
||||
embedding,
|
||||
location=location,
|
||||
url=url,
|
||||
port=port,
|
||||
grpc_port=grpc_port,
|
||||
prefer_grpc=prefer_grpc,
|
||||
https=https,
|
||||
api_key=api_key,
|
||||
prefix=prefix,
|
||||
timeout=timeout,
|
||||
host=host,
|
||||
path=path,
|
||||
collection_name=collection_name,
|
||||
distance_func=distance_func,
|
||||
content_payload_key=content_payload_key,
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
cls: Type[Qdrant],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
@ -244,7 +200,7 @@ class Qdrant(VectorStore):
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
**kwargs: Any,
|
||||
) -> "Qdrant":
|
||||
) -> Qdrant:
|
||||
"""Construct Qdrant wrapper from raw documents.
|
||||
|
||||
Args:
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, root_validator
|
||||
@ -227,7 +227,7 @@ class Redis(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
cls: Type[Redis],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Wrapper around weaviate vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
@ -104,11 +104,11 @@ class Weaviate(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
cls: Type[Weaviate],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> VectorStore:
|
||||
) -> Weaviate:
|
||||
"""Not implemented for Weaviate yet."""
|
||||
raise NotImplementedError("weaviate does not currently support `from_texts`.")
|
||||
|
Loading…
Reference in New Issue
Block a user