mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-09 01:00:01 +00:00
This PR adds annotations in the community package. Annotations are only strictly needed in subclasses of BaseModel for pydantic 2 compatibility. This PR adds some unnecessary annotations, but they're not bad to have regardless for documentation pages.
273 lines
9.5 KiB
Python
273 lines
9.5 KiB
Python
from typing import Any, Iterable, List, Optional, Tuple
|
|
from uuid import uuid4
|
|
|
|
import numpy as np
|
|
import requests
|
|
from langchain_core.documents import Document
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.utils import get_from_env
|
|
from langchain_core.vectorstores import VectorStore
|
|
|
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
|
|
|
|
|
class SemaDB(VectorStore):
    """`SemaDB` vector store.

    This vector store is a wrapper around the SemaDB database. All operations
    are plain HTTP requests against ``BASE_URL`` authenticated with the
    RapidAPI headers built by :attr:`headers`.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import SemaDB

            db = SemaDB('mycollection', 768, embeddings, DistanceStrategy.COSINE)

    """

    # RapidAPI host serving SemaDB; also sent as the "X-RapidAPI-Host" header.
    HOST: str = "semadb.p.rapidapi.com"
    BASE_URL: str = "https://" + HOST

    def __init__(
        self,
        collection_name: str,
        vector_size: int,
        embedding: Embeddings,
        distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
        api_key: str = "",
    ):
        """Initialize the SemaDB vector store.

        Args:
            collection_name: Identifier of the SemaDB collection to operate on.
            vector_size: Expected dimensionality of every embedding vector.
            embedding: Embedding model used for documents and queries.
            distance_strategy: Distance metric of the collection.
            api_key: API key sent with every request. When empty, the value is
                looked up via ``get_from_env("api_key", "SEMADB_API_KEY")``.
        """
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.api_key = api_key or get_from_env("api_key", "SEMADB_API_KEY")
        self._embedding = embedding
        self.distance_strategy = distance_strategy

    @property
    def headers(self) -> dict:
        """Return the common headers sent with every request."""
        return {
            "content-type": "application/json",
            "X-RapidAPI-Key": self.api_key,
            "X-RapidAPI-Host": SemaDB.HOST,
        }

    def _get_internal_distance_strategy(self) -> str:
        """Return the SemaDB metric name for the configured distance strategy.

        Returns:
            One of ``"euclidean"``, ``"dot"`` or ``"cosine"``.

        Raises:
            ValueError: If the configured strategy is unsupported or unknown.
        """
        if self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
            return "euclidean"
        elif self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
            raise ValueError("Max inner product is not supported by SemaDB")
        elif self.distance_strategy == DistanceStrategy.DOT_PRODUCT:
            return "dot"
        elif self.distance_strategy == DistanceStrategy.JACCARD:
            # The message previously (incorrectly) blamed max inner product.
            raise ValueError("Jaccard is not supported by SemaDB")
        elif self.distance_strategy == DistanceStrategy.COSINE:
            return "cosine"
        else:
            raise ValueError(f"Unknown distance strategy {self.distance_strategy}")

    def create_collection(self) -> bool:
        """Create the corresponding collection in SemaDB.

        Returns:
            True if the server answered with HTTP 200, False otherwise.

        Raises:
            ValueError: If the configured distance strategy is not supported.
        """
        payload = {
            "id": self.collection_name,
            "vectorSize": self.vector_size,
            "distanceMetric": self._get_internal_distance_strategy(),
        }
        response = requests.post(
            SemaDB.BASE_URL + "/collections",
            json=payload,
            headers=self.headers,
        )
        return response.status_code == 200

    def delete_collection(self) -> bool:
        """Delete the corresponding collection in SemaDB.

        Returns:
            True if the server answered with HTTP 200, False otherwise.
        """
        response = requests.delete(
            SemaDB.BASE_URL + f"/collections/{self.collection_name}",
            headers=self.headers,
        )
        return response.status_code == 200

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        batch_size: int = 1000,
        **kwargs: Any,
    ) -> List[str]:
        """Embed texts and add them to the vector store.

        Args:
            texts: Texts to embed and insert.
            metadatas: Optional metadata dicts paired positionally with
                ``texts``. The text itself is stored under the ``"text"``
                metadata key and overrides any user-supplied ``"text"`` entry.
            batch_size: Maximum number of points sent per insert request.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            The generated ids of the inserted points, in insertion order.

        Raises:
            ValueError: If the embedding size differs from ``vector_size``,
                if an insert request does not return HTTP 200, or if the
                server reports failed ranges.
        """
        if not isinstance(texts, list):
            texts = list(texts)
        # Nothing to insert: avoid calling the embedding model and avoid
        # indexing into an empty embeddings list below.
        if not texts:
            return []
        embeddings = self._embedding.embed_documents(texts)
        # Check dimensions
        if len(embeddings[0]) != self.vector_size:
            raise ValueError(
                f"Embedding size mismatch {len(embeddings[0])} != {self.vector_size}"
            )
        # Normalise if needed: for the cosine strategy vectors are scaled to
        # unit length before upload.
        if self.distance_strategy == DistanceStrategy.COSINE:
            embed_matrix = np.array(embeddings)
            embed_matrix = embed_matrix / np.linalg.norm(
                embed_matrix, axis=1, keepdims=True
            )
            embeddings = embed_matrix.tolist()
        # Create points
        ids: List[str] = []
        points = []
        if metadatas is not None:
            for text, embedding, metadata in zip(texts, embeddings, metadatas):
                new_id = str(uuid4())
                ids.append(new_id)
                points.append(
                    {
                        "id": new_id,
                        "vector": embedding,
                        "metadata": {**metadata, "text": text},
                    }
                )
        else:
            for text, embedding in zip(texts, embeddings):
                new_id = str(uuid4())
                ids.append(new_id)
                points.append(
                    {
                        "id": new_id,
                        "vector": embedding,
                        "metadata": {"text": text},
                    }
                )
        # Insert points in batches
        for i in range(0, len(points), batch_size):
            batch = points[i : i + batch_size]
            response = requests.post(
                SemaDB.BASE_URL + f"/collections/{self.collection_name}/points",
                json={"points": batch},
                headers=self.headers,
            )
            if response.status_code != 200:
                raise ValueError(f"Error adding points: {response.text}")
            failed_ranges = response.json()["failedRanges"]
            if failed_ranges:
                raise ValueError(f"Error adding points: {failed_ranges}")
        # Return ids
        return ids

    @property
    def embeddings(self) -> Embeddings:
        """Return the embedding model used by this store."""
        return self._embedding

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete points by id.

        Args:
            ids: List of ids to delete. The value is forwarded as-is in the
                request payload.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            Optional[bool]: True if the server answered with HTTP 200 and
                reported no failed points, False otherwise.
        """
        payload = {
            "ids": ids,
        }
        response = requests.delete(
            SemaDB.BASE_URL + f"/collections/{self.collection_name}/points",
            json=payload,
            headers=self.headers,
        )
        return response.status_code == 200 and len(response.json()["failedPoints"]) == 0

    def _search_points(self, embedding: List[float], k: int = 4) -> List[dict]:
        """Search the collection for the points nearest to ``embedding``.

        Args:
            embedding: Query vector.
            k: Maximum number of points requested from the server.

        Returns:
            The ``"points"`` list from the search response.

        Raises:
            ValueError: If the search request does not return HTTP 200.
        """
        # Normalise if needed, mirroring what add_texts does for stored vectors.
        if self.distance_strategy == DistanceStrategy.COSINE:
            vec = np.array(embedding)
            vec = vec / np.linalg.norm(vec)
            embedding = vec.tolist()
        # Perform search request
        payload = {
            "vector": embedding,
            "limit": k,
        }
        response = requests.post(
            SemaDB.BASE_URL + f"/collections/{self.collection_name}/points/search",
            json=payload,
            headers=self.headers,
        )
        if response.status_code != 200:
            raise ValueError(f"Error searching: {response.text}")
        return response.json()["points"]

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to embed and search for.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            List of Documents most similar to the query.
        """
        query_embedding = self._embedding.embed_query(query)
        return self.similarity_search_by_vector(query_embedding, k=k)

    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with distance.

        Args:
            query: Text to embed and search for.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            List of ``(Document, distance)`` tuples, where the distance is the
            ``"distance"`` field reported by the server for each point.
        """
        query_embedding = self._embedding.embed_query(query)
        points = self._search_points(query_embedding, k=k)
        return [
            (
                Document(page_content=p["metadata"]["text"], metadata=p["metadata"]),
                p["distance"],
            )
            for p in points
        ]

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            List of Documents most similar to the query vector.
        """
        points = self._search_points(embedding, k=k)
        return [
            Document(page_content=p["metadata"]["text"], metadata=p["metadata"])
            for p in points
        ]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        collection_name: str = "",
        vector_size: int = 0,
        api_key: str = "",
        distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
        **kwargs: Any,
    ) -> "SemaDB":
        """Return VectorStore initialized from texts and embeddings.

        Creates the collection on the server and then inserts the texts.

        Args:
            texts: Texts to embed and insert.
            embedding: Embedding model used for documents and queries.
            metadatas: Optional metadata dicts paired positionally with texts.
            collection_name: Identifier of the collection to create. Required.
            vector_size: Dimensionality of the embedding vectors. Required.
            api_key: API key sent with every request. Required.
            distance_strategy: Distance metric of the new collection.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            The store bound to the newly created collection.

        Raises:
            ValueError: If a required argument is missing, if the collection
                cannot be created, or if inserting the texts fails.
        """
        if not collection_name:
            raise ValueError("Collection name must be provided")
        if not vector_size:
            raise ValueError("Vector size must be provided")
        if not api_key:
            raise ValueError("API key must be provided")
        semadb = cls(
            collection_name,
            vector_size,
            embedding,
            distance_strategy=distance_strategy,
            api_key=api_key,
        )
        if not semadb.create_collection():
            raise ValueError("Error creating collection")
        semadb.add_texts(texts, metadatas=metadatas)
        return semadb
|