From b6779b91eaf22773ee7f9336296dee77ccfe284f Mon Sep 17 00:00:00 2001 From: Dev 2049 Date: Tue, 30 May 2023 14:46:44 -0700 Subject: [PATCH] rfc --- langchain/docstore/base.py | 155 ++++++++++++++++++++++++++++++- langchain/vectorstores/chroma.py | 45 +++++++-- 2 files changed, 188 insertions(+), 12 deletions(-) diff --git a/langchain/docstore/base.py b/langchain/docstore/base.py index 4a91680c732..bfc6a6e6c82 100644 --- a/langchain/docstore/base.py +++ b/langchain/docstore/base.py @@ -1,12 +1,29 @@ """Interface to access to place that stores documents.""" +from __future__ import annotations + +import json +import uuid from abc import ABC, abstractmethod -from typing import Dict, Union +from pathlib import Path +from typing import ( + Dict, + Generic, + Iterator, + List, + Optional, + Sequence, + Set, + TypeVar, + Union, +) from langchain.docstore.document import Document +UID_TYPE = TypeVar("UID_TYPE") + class Docstore(ABC): - """Interface to access to place that stores documents.""" + """Interface to place that stores documents.""" @abstractmethod def search(self, search: str) -> Union[str, Document]: @@ -17,9 +34,141 @@ class Docstore(ABC): """ -class AddableMixin(ABC): +class AddableMixin(object, ABC): """Mixin class that supports adding texts.""" @abstractmethod def add(self, texts: Dict[str, Document]) -> None: """Add more documents.""" + + +class DocManager(ABC, Generic[UID_TYPE]): + def add(self, doc: Document) -> UID_TYPE: + if self.contains_doc(doc): + raise ValueError + uid = self.generate_uid() + self._add(doc, uid) + return uid + + def lazy_add_docs(self, docs: Iterator[Document]) -> Iterator[UID_TYPE]: + for doc in docs: + yield self.add(doc) + + def add_docs(self, docs: Sequence[Document]) -> List[UID_TYPE]: + return list(self.lazy_add_docs(docs)) + + def add_text(self, text: str, metadata: Optional[dict] = None) -> UID_TYPE: + _metadata = metadata or {} + return self.add(Document(page_content=text, metadata=_metadata)) + + def lazy_add_texts( + self, texts: Iterator[str], metadatas: Optional[Iterator[dict]] = None + ) -> Iterator[UID_TYPE]: + _metadatas = metadatas or ({} for _ in texts) + for text, metadata in zip(texts, _metadatas): + yield self.add_text(text, metadata=metadata) + + def add_texts( + self, texts: List[str], metadatas: Optional[List[dict]] = None + ) -> List[UID_TYPE]: + return list(self.lazy_add_texts(texts, metadatas=metadatas)) + + @abstractmethod + def _add(self, doc: Document, uid: UID_TYPE) -> None: + """""" + + def delete(self, uid: UID_TYPE) -> None: + """""" + if not self.get(uid): + raise ValueError + + return self._delete(uid) + + @abstractmethod + def _delete(self, uid: UID_TYPE) -> None: + """""" + + @abstractmethod + def contains_doc(self, doc: Document) -> bool: + """""" + + @abstractmethod + def get_doc_id(self, doc: Document) -> UID_TYPE: + """""" + + @abstractmethod + def get(self, uid: UID_TYPE) -> Document: + """""" + + @abstractmethod + def generate_uid(self) -> UID_TYPE: + """""" + uid = self._generate_uid() + if self.get(uid): + raise ValueError + return uid + + @abstractmethod + def _generate_uid(self) -> UID_TYPE: + """""" + + def persist(self, path: Union[str, Path]) -> None: + raise NotImplementedError + + @classmethod + def load(cls, path: Union[str, Path]) -> DocManager: + raise NotImplementedError + + +class SimpleDocManager(DocManager[str]): + def __init__( + self, + uid_doc_map: Optional[Dict[str, Document]] = None, + doc_uid_map: Optional[Dict[str, str]] = None, + ): + self.uid_doc_map = uid_doc_map or {} + self.doc_uid_map = doc_uid_map or {} + if len(self.uid_doc_map) != len(self.doc_uid_map): + raise ValueError + + def _add(self, doc: Document, uid: str) -> None: + self.uid_doc_map[uid] = doc + self.doc_uid_map[self.serialize(doc)] = uid + + def _delete(self, uid: str) -> None: + doc = self.uid_doc_map[uid] + del self.doc_uid_map[self.serialize(doc)] + del self.uid_doc_map[uid] + + def contains_doc(self, doc: Document) -> bool: + return self.serialize(doc) in self.doc_uid_map + + def get_doc_id(self, doc: Document) -> str: + return self.doc_uid_map[self.serialize(doc)] + + def serialize(self, doc: Document) -> str: + # Assumes metadata is JSON-serializable. + return json.dumps(doc.dict(), sort_keys=True) + + def _generate_uid(self) -> str: + return str(uuid.uuid4()) + + def persist(self, path: Union[str, Path]) -> None: + uid_doc_map = { + uid: self.serialize(doc) for uid, doc in self.uid_doc_map.items() + } + with open(path, "w") as f: + json.dump(uid_doc_map, f, indent=2) + + @classmethod + def load(cls, path: Union[str, Path]) -> SimpleDocManager: + with open(path, "r") as f: + serialized_uid_doc_map = json.load(f) + uid_doc_map = { + uid: Document( + page_content=doc_dict["page_content"], metadata=doc_dict["metadata"] + ) + for uid, doc_dict in serialized_uid_doc_map.items() + } + doc_uid_map = {v: k for k, v in serialized_uid_doc_map.items()} + return cls(uid_doc_map=uid_doc_map, doc_uid_map=doc_uid_map) diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index b10a81ed9ba..4d9fa7f224d 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Ty import numpy as np +from langchain.docstore.base import DocManager from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.utils import xor_args @@ -63,6 +64,7 @@ class Chroma(VectorStore): client_settings: Optional[chromadb.config.Settings] = None, collection_metadata: Optional[Dict] = None, client: Optional[chromadb.Client] = None, + doc_manager: Optional[DocManager] = None, ) -> None: """Initialize with Chroma client.""" try: @@ -97,6 +99,7 @@ class Chroma(VectorStore): else None, metadata=collection_metadata, ) + self._doc_manager = doc_manager @xor_args(("query_texts", "query_embeddings")) def __query_collection( @@ -152,6 +155,10 @@ class Chroma(VectorStore): List[str]: List of IDs of the added texts. """ # TODO: Handle the case where the user doesn't provide ids on the Collection + if self._doc_manager and ids: + raise ValueError + elif not ids: + ids = self._doc_manager.add_texts(texts, metadatas=metadatas) if ids is None: ids = [str(uuid.uuid1()) for _ in texts] embeddings = None @@ -346,20 +353,40 @@ class Chroma(VectorStore): document_id (str): ID of the document to update. document (Document): Document to update. """ - text = document.page_content - metadata = document.metadata + return self.update_existing_documents([document_id], [document]) + + def update_existing_documents( + self, document_ids: List[str], documents: List[Document] + ) -> None: + """Update a document in the collection.""" + texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents)) if self._embedding_function is None: raise ValueError( "For update, you must specify an embedding function on creation." ) - embeddings = self._embedding_function.embed_documents(list(text)) + embeddings = self._embedding_function.embed_documents(list(texts)) + self._collection.update(document_ids, embeddings, list(texts), list(metadatas)) - self._collection.update( - ids=[document_id], - embeddings=[embeddings[0]], - documents=[text], - metadatas=[metadata], - ) + def update(self, documents: List[Document]) -> List[str]: + if self._doc_manager: + has_id = [self._doc_manager.contains_doc(doc) for doc in documents] + old_ids = [ + self._doc_manager.get_doc_id(d) for h, d in zip(has_id, documents) if h + ] + new_ids = self.add_documents( + [d for h, d in zip(has_id, documents) if not h] + ) + all_ids = [] + old_idx, new_idx = 0, 0 + for i, is_old in enumerate(has_id): + if is_old: + all_ids.append(old_ids[old_idx]) + old_idx += 1 + else: + all_ids.append(new_ids[new_idx]) + new_idx += 1 + else: + return self.add_documents(documents) @classmethod def from_texts(