Compare commits

...

5 Commits

Author    SHA1        Message  Date
Dev 2049  bda5b0c050  nit      2023-05-30 14:58:16 -07:00
Dev 2049  37dd0304de  lint     2023-05-30 14:52:09 -07:00
Dev 2049  562ee92e70  fmt      2023-05-30 14:51:03 -07:00
Dev 2049  7912e13522  rfc      2023-05-30 14:50:40 -07:00
Dev 2049  b6779b91ea  rfc      2023-05-30 14:46:44 -07:00
2 changed files with 194 additions and 15 deletions

View File: langchain/docstore/base.py

@@ -1,12 +1,28 @@
"""Interface to access to place that stores documents.""" """Interface to access to place that stores documents."""
from __future__ import annotations
import json
import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Union from pathlib import Path
from typing import (
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
TypeVar,
Union,
)
from langchain.docstore.document import Document from langchain.docstore.document import Document
UID_TYPE = TypeVar("UID_TYPE")
class Docstore(ABC): class Docstore(ABC):
"""Interface to access to place that stores documents.""" """Interface to place that stores documents."""
@abstractmethod @abstractmethod
def search(self, search: str) -> Union[str, Document]: def search(self, search: str) -> Union[str, Document]:
@@ -23,3 +39,135 @@ class AddableMixin(ABC):
     @abstractmethod
     def add(self, texts: Dict[str, Document]) -> None:
         """Add more documents."""
+
+
+class DocManager(ABC, Generic[UID_TYPE]):
+    """Manage a mapping from unique ids to documents."""
+
+    def add(self, doc: Document) -> UID_TYPE:
+        """Add a document and return the uid it was stored under."""
+        if self.contains_doc(doc):
+            raise ValueError("Document is already present in the docstore.")
+        uid = self.generate_uid()
+        self._add(doc, uid)
+        return uid
+
+    def lazy_add_docs(self, docs: Iterable[Document]) -> Iterator[UID_TYPE]:
+        """Lazily add documents, yielding each uid as the document is stored."""
+        for doc in docs:
+            yield self.add(doc)
+
+    def add_docs(self, docs: Iterable[Document]) -> List[UID_TYPE]:
+        """Add documents and return their uids."""
+        return list(self.lazy_add_docs(docs))
+
+    def add_text(self, text: str, metadata: Optional[dict] = None) -> UID_TYPE:
+        """Wrap a string in a Document and add it."""
+        _metadata = metadata or {}
+        return self.add(Document(page_content=text, metadata=_metadata))
+
+    def lazy_add_texts(
+        self, texts: Iterable[str], metadatas: Optional[Iterable[dict]] = None
+    ) -> Iterator[UID_TYPE]:
+        """Lazily add strings, yielding each uid as the text is stored."""
+        _metadatas = metadatas or ({} for _ in texts)
+        for text, metadata in zip(texts, _metadatas):
+            yield self.add_text(text, metadata=metadata)
+
+    def add_texts(
+        self, texts: Iterable[str], metadatas: Optional[Iterable[dict]] = None
+    ) -> List[UID_TYPE]:
+        """Add strings and return their uids."""
+        return list(self.lazy_add_texts(texts, metadatas=metadatas))
+
+    @abstractmethod
+    def _add(self, doc: Document, uid: UID_TYPE) -> None:
+        """Store the document under the given uid."""
+
+    def delete(self, uid: UID_TYPE) -> None:
+        """Delete the document stored under the given uid."""
+        if not self.get(uid):
+            raise ValueError(f"No document found for uid {uid}.")
+        return self._delete(uid)
+
+    @abstractmethod
+    def _delete(self, uid: UID_TYPE) -> None:
+        """Remove the document stored under the given uid."""
+
+    @abstractmethod
+    def contains_doc(self, doc: Document) -> bool:
+        """Check whether the document is already stored."""
+
+    @abstractmethod
+    def get_doc_id(self, doc: Document) -> UID_TYPE:
+        """Return the uid under which the document is stored."""
+
+    @abstractmethod
+    def get(self, uid: UID_TYPE) -> Optional[Document]:
+        """Return the document stored under uid, or None if there is none."""
+
+    def generate_uid(self) -> UID_TYPE:
+        """Generate a uid that is not already in use."""
+        uid = self._generate_uid()
+        if self.get(uid):
+            raise ValueError(f"Generated uid {uid} is already in use.")
+        return uid
+
+    @abstractmethod
+    def _generate_uid(self) -> UID_TYPE:
+        """Generate a candidate uid."""
+
+    def persist(self, path: Union[str, Path]) -> None:
+        """Persist the docstore to disk."""
+        raise NotImplementedError
+
+    @classmethod
+    def load(cls, path: Union[str, Path]) -> DocManager:
+        """Load a persisted docstore from disk."""
+        raise NotImplementedError
+
+
+class SimpleDocManager(DocManager[str]):
+    """In-memory DocManager backed by a pair of dicts."""
+
+    def __init__(
+        self,
+        uid_doc_map: Optional[Dict[str, Document]] = None,
+        doc_uid_map: Optional[Dict[str, str]] = None,
+    ):
+        self.uid_doc_map = uid_doc_map or {}
+        self.doc_uid_map = doc_uid_map or {}
+        if len(self.uid_doc_map) != len(self.doc_uid_map):
+            raise ValueError("uid_doc_map and doc_uid_map must be the same size.")
+
+    def _add(self, doc: Document, uid: str) -> None:
+        self.uid_doc_map[uid] = doc
+        self.doc_uid_map[self.serialize(doc)] = uid
+
+    def _delete(self, uid: str) -> None:
+        doc = self.uid_doc_map[uid]
+        del self.doc_uid_map[self.serialize(doc)]
+        del self.uid_doc_map[uid]
+
+    def contains_doc(self, doc: Document) -> bool:
+        return self.serialize(doc) in self.doc_uid_map
+
+    def get_doc_id(self, doc: Document) -> str:
+        return self.doc_uid_map[self.serialize(doc)]
+
+    def get(self, uid: str) -> Optional[Document]:
+        return self.uid_doc_map.get(uid)
+
+    def serialize(self, doc: Document) -> str:
+        # Assumes metadata is JSON-serializable.
+        return json.dumps(doc.dict(), sort_keys=True)
+
+    def _generate_uid(self) -> str:
+        return str(uuid.uuid4())
+
+    def persist(self, path: Union[str, Path]) -> None:
+        uid_doc_map = {
+            uid: self.serialize(doc) for uid, doc in self.uid_doc_map.items()
+        }
+        with open(path, "w") as f:
+            json.dump(uid_doc_map, f, indent=2)
+
+    @classmethod
+    def load(cls, path: Union[str, Path]) -> SimpleDocManager:
+        with open(path, "r") as f:
+            serialized_uid_doc_map = json.load(f)
+        # persist() stores each document as a serialized JSON string, so decode
+        # it before rebuilding the Document.
+        uid_doc_map = {}
+        for uid, serialized in serialized_uid_doc_map.items():
+            doc_dict = json.loads(serialized)
+            uid_doc_map[uid] = Document(
+                page_content=doc_dict["page_content"], metadata=doc_dict["metadata"]
+            )
+        doc_uid_map = {v: k for k, v in serialized_uid_doc_map.items()}
+        return cls(uid_doc_map=uid_doc_map, doc_uid_map=doc_uid_map)
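
For reference, a short usage sketch of the SimpleDocManager introduced above. It is not part of the diff; it assumes the file lands at langchain/docstore/base.py (the path the second file imports DocManager from) and that document metadata is JSON-serializable, as serialize() requires.

# Sketch only, not part of the PR: exercises the SimpleDocManager added above.
from langchain.docstore.base import SimpleDocManager
from langchain.docstore.document import Document

manager = SimpleDocManager()

# add_text wraps the string in a Document and returns a generated uuid4 string.
uid = manager.add_text("hello world", metadata={"source": "example.txt"})
doc = manager.get(uid)
assert doc is not None and doc.page_content == "hello world"

# Documents are deduplicated by their JSON serialization, so re-adding the
# same content/metadata pair would raise ValueError.
assert manager.contains_doc(doc)
assert manager.get_doc_id(doc) == uid

# Whole Document objects can be added directly as well.
uid2 = manager.add(Document(page_content="second doc", metadata={}))
assert uid2 != uid

# Round-trip through disk; the path is illustrative.
manager.persist("/tmp/docstore.json")
restored = SimpleDocManager.load("/tmp/docstore.json")
assert restored.get(uid).page_content == "hello world"

Because contains_doc keys on the serialized document, two documents with identical page_content and metadata share a single uid; that is what lets the Chroma changes below reuse ids for documents the manager has already seen.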

View File: langchain/vectorstores/chroma.py

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Ty
 import numpy as np
 
+from langchain.docstore.base import DocManager
 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
 from langchain.utils import xor_args
@@ -63,6 +64,7 @@ class Chroma(VectorStore):
         client_settings: Optional[chromadb.config.Settings] = None,
         collection_metadata: Optional[Dict] = None,
         client: Optional[chromadb.Client] = None,
+        doc_manager: Optional[DocManager] = None,
     ) -> None:
         """Initialize with Chroma client."""
         try:
@@ -97,6 +99,7 @@ class Chroma(VectorStore):
             else None,
             metadata=collection_metadata,
         )
+        self._doc_manager = doc_manager
 
     @xor_args(("query_texts", "query_embeddings"))
     def __query_collection(
@@ -152,15 +155,21 @@ class Chroma(VectorStore):
             List[str]: List of IDs of the added texts.
         """
         # TODO: Handle the case where the user doesn't provide ids on the Collection
-        if ids is None:
-            ids = [str(uuid.uuid1()) for _ in texts]
+        if self._doc_manager and ids:
+            raise ValueError
+        elif self._doc_manager:
+            _ids = self._doc_manager.add_texts(texts, metadatas=metadatas)
+        elif ids is None:
+            _ids = [str(uuid.uuid1()) for _ in texts]
+        else:
+            _ids = ids
         embeddings = None
         if self._embedding_function is not None:
             embeddings = self._embedding_function.embed_documents(list(texts))
         self._collection.add(
-            metadatas=metadatas, embeddings=embeddings, documents=texts, ids=ids
+            metadatas=metadatas, embeddings=embeddings, documents=texts, ids=_ids
         )
-        return ids
+        return _ids
 
     def similarity_search(
         self,
@@ -346,20 +355,42 @@ class Chroma(VectorStore):
             document_id (str): ID of the document to update.
             document (Document): Document to update.
         """
-        text = document.page_content
-        metadata = document.metadata
+        return self.update_existing_documents([document_id], [document])
+
+    def update_existing_documents(
+        self, document_ids: List[str], documents: List[Document]
+    ) -> None:
+        """Update a document in the collection."""
+        texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
+        texts_list = list(texts)
         if self._embedding_function is None:
             raise ValueError(
                 "For update, you must specify an embedding function on creation."
             )
-        embeddings = self._embedding_function.embed_documents(list(text))
-
-        self._collection.update(
-            ids=[document_id],
-            embeddings=[embeddings[0]],
-            documents=[text],
-            metadatas=[metadata],
-        )
+        embeddings = self._embedding_function.embed_documents(texts_list)
+        self._collection.update(document_ids, embeddings, texts_list, list(metadatas))
+
+    def update(self, documents: List[Document]) -> List[str]:
+        if self._doc_manager:
+            has_id = [self._doc_manager.contains_doc(doc) for doc in documents]
+            old_ids = [
+                self._doc_manager.get_doc_id(d) for h, d in zip(has_id, documents) if h
+            ]
+            new_ids = self.add_documents(
+                [d for h, d in zip(has_id, documents) if not h]
+            )
+            all_ids = []
+            old_idx, new_idx = 0, 0
+            for i, is_old in enumerate(has_id):
+                if is_old:
+                    all_ids.append(old_ids[old_idx])
+                    old_idx += 1
+                else:
+                    all_ids.append(new_ids[new_idx])
+                    new_idx += 1
+            return all_ids
+        else:
+            return self.add_documents(documents)
 
     @classmethod
     def from_texts(
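
For reference, a short sketch of how the new doc_manager hook and update() fit together. It is not part of the diff: OpenAIEmbeddings and the other Chroma constructor arguments already exist in LangChain at this point; only doc_manager, update_existing_documents(), and the new update() come from this change, and running it requires chromadb plus embedding credentials.

# Sketch only, not part of the PR: wires a SimpleDocManager into Chroma.
from langchain.docstore.base import SimpleDocManager
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma

doc_manager = SimpleDocManager()
store = Chroma(
    embedding_function=OpenAIEmbeddings(),
    doc_manager=doc_manager,
)

docs = [
    Document(page_content="alpha", metadata={"source": "a.txt"}),
    Document(page_content="beta", metadata={"source": "b.txt"}),
]

# With a doc_manager attached, ids come from the manager; passing ids
# explicitly to add_texts/add_documents raises ValueError.
ids = store.add_documents(docs)

# update() acts as an upsert: documents the manager has already seen keep
# their existing ids, new ones are embedded and added.
more = docs + [Document(page_content="gamma", metadata={"source": "c.txt"})]
all_ids = store.update(more)
assert all_ids[:2] == ids

The id bookkeeping lives entirely in the doc_manager; Chroma only asks it for ids and membership, so any DocManager implementation, not just the in-memory SimpleDocManager, can back the same update() behavior.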