Compare commits

...

5 Commits

Author SHA1 Message Date
Dev 2049
bda5b0c050 nit 2023-05-30 14:58:16 -07:00
Dev 2049
37dd0304de lint 2023-05-30 14:52:09 -07:00
Dev 2049
562ee92e70 fmt 2023-05-30 14:51:03 -07:00
Dev 2049
7912e13522 rfc 2023-05-30 14:50:40 -07:00
Dev 2049
b6779b91ea rfc 2023-05-30 14:46:44 -07:00
2 changed files with 194 additions and 15 deletions

View File

@@ -1,12 +1,28 @@
"""Interface to access to place that stores documents."""
from __future__ import annotations
import json
import uuid
from abc import ABC, abstractmethod
from typing import Dict, Union
from pathlib import Path
from typing import (
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
TypeVar,
Union,
)
from langchain.docstore.document import Document
UID_TYPE = TypeVar("UID_TYPE")
class Docstore(ABC):
"""Interface to access to place that stores documents."""
"""Interface to place that stores documents."""
@abstractmethod
def search(self, search: str) -> Union[str, Document]:
@@ -23,3 +39,135 @@ class AddableMixin(ABC):
@abstractmethod
def add(self, texts: Dict[str, Document]) -> None:
"""Add more documents."""
class DocManager(ABC, Generic[UID_TYPE]):
def add(self, doc: Document) -> UID_TYPE:
if self.contains_doc(doc):
raise ValueError
uid = self.generate_uid()
self._add(doc, uid)
return uid
def lazy_add_docs(self, docs: Iterable[Document]) -> Iterator[UID_TYPE]:
for doc in docs:
yield self.add(doc)
def add_docs(self, docs: Iterable[Document]) -> List[UID_TYPE]:
return list(self.lazy_add_docs(docs))
def add_text(self, text: str, metadata: Optional[dict] = None) -> UID_TYPE:
_metadata = metadata or {}
return self.add(Document(page_content=text, metadata=_metadata))
def lazy_add_texts(
self, texts: Iterable[str], metadatas: Optional[Iterable[dict]] = None
) -> Iterator[UID_TYPE]:
_metadatas = metadatas or ({} for _ in texts)
for text, metadata in zip(texts, _metadatas):
yield self.add_text(text, metadata=metadata)
def add_texts(
self, texts: Iterable[str], metadatas: Optional[Iterable[dict]] = None
) -> List[UID_TYPE]:
return list(self.lazy_add_texts(texts, metadatas=metadatas))
@abstractmethod
def _add(self, doc: Document, uid: UID_TYPE) -> None:
""""""
def delete(self, uid: UID_TYPE) -> None:
""""""
if not self.get(uid):
raise ValueError
return self._delete(uid)
@abstractmethod
def _delete(self, uid: UID_TYPE) -> None:
""""""
@abstractmethod
def contains_doc(self, doc: Document) -> bool:
""""""
@abstractmethod
def get_doc_id(self, doc: Document) -> UID_TYPE:
""""""
@abstractmethod
def get(self, uid: UID_TYPE) -> Document:
""""""
@abstractmethod
def generate_uid(self) -> UID_TYPE:
""""""
uid = self._generate_uid()
if self.get(uid):
raise ValueError
return uid
@abstractmethod
def _generate_uid(self) -> UID_TYPE:
""""""
def persist(self, path: Union[str, Path]) -> None:
raise NotImplementedError
@classmethod
def load(cls, path: Union[str, Path]) -> DocManager:
raise NotImplementedError
class SimpleDocManager(DocManager[str]):
def __init__(
self,
uid_doc_map: Optional[Dict[str, Document]] = None,
doc_uid_map: Optional[Dict[str, str]] = None,
):
self.uid_doc_map = uid_doc_map or {}
self.doc_uid_map = doc_uid_map or {}
if len(self.uid_doc_map) != len(self.doc_uid_map):
raise ValueError
def _add(self, doc: Document, uid: str) -> None:
self.uid_doc_map[uid] = doc
self.doc_uid_map[self.serialize(doc)] = uid
def _delete(self, uid: str) -> None:
doc = self.uid_doc_map[uid]
del self.doc_uid_map[self.serialize(doc)]
del self.uid_doc_map[uid]
def contains_doc(self, doc: Document) -> bool:
return self.serialize(doc) in self.doc_uid_map
def get_doc_id(self, doc: Document) -> str:
return self.doc_uid_map[self.serialize(doc)]
def serialize(self, doc: Document) -> str:
# Assumes metadata is JSON-serializable.
return json.dumps(doc.dict(), sort_keys=True)
def _generate_uid(self) -> str:
return str(uuid.uuid4())
def persist(self, path: Union[str, Path]) -> None:
uid_doc_map = {
uid: self.serialize(doc) for uid, doc in self.uid_doc_map.items()
}
with open(path, "w") as f:
json.dump(uid_doc_map, f, indent=2)
@classmethod
def load(cls, path: Union[str, Path]) -> SimpleDocManager:
with open(path, "r") as f:
serialized_uid_doc_map = json.load(f)
uid_doc_map = {
uid: Document(
page_content=doc_dict["page_content"], metadata=doc_dict["metadata"]
)
for uid, doc_dict in serialized_uid_doc_map.items()
}
doc_uid_map = {v: k for k, v in serialized_uid_doc_map.items()}
return cls(uid_doc_map=uid_doc_map, doc_uid_map=doc_uid_map)

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Ty
import numpy as np
from langchain.docstore.base import DocManager
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import xor_args
@@ -63,6 +64,7 @@ class Chroma(VectorStore):
client_settings: Optional[chromadb.config.Settings] = None,
collection_metadata: Optional[Dict] = None,
client: Optional[chromadb.Client] = None,
doc_manager: Optional[DocManager] = None,
) -> None:
"""Initialize with Chroma client."""
try:
@@ -97,6 +99,7 @@ class Chroma(VectorStore):
else None,
metadata=collection_metadata,
)
self._doc_manager = doc_manager
@xor_args(("query_texts", "query_embeddings"))
def __query_collection(
@@ -152,15 +155,21 @@ class Chroma(VectorStore):
List[str]: List of IDs of the added texts.
"""
# TODO: Handle the case where the user doesn't provide ids on the Collection
if ids is None:
ids = [str(uuid.uuid1()) for _ in texts]
if self._doc_manager and ids:
raise ValueError
elif self._doc_manager:
_ids = self._doc_manager.add_texts(texts, metadatas=metadatas)
elif ids is None:
_ids = [str(uuid.uuid1()) for _ in texts]
else:
_ids = ids
embeddings = None
if self._embedding_function is not None:
embeddings = self._embedding_function.embed_documents(list(texts))
self._collection.add(
metadatas=metadatas, embeddings=embeddings, documents=texts, ids=ids
metadatas=metadatas, embeddings=embeddings, documents=texts, ids=_ids
)
return ids
return _ids
def similarity_search(
self,
@@ -346,20 +355,42 @@ class Chroma(VectorStore):
document_id (str): ID of the document to update.
document (Document): Document to update.
"""
text = document.page_content
metadata = document.metadata
return self.update_existing_documents([document_id], [document])
def update_existing_documents(
self, document_ids: List[str], documents: List[Document]
) -> None:
"""Update a document in the collection."""
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
texts_list = list(texts)
if self._embedding_function is None:
raise ValueError(
"For update, you must specify an embedding function on creation."
)
embeddings = self._embedding_function.embed_documents(list(text))
embeddings = self._embedding_function.embed_documents(texts_list)
self._collection.update(document_ids, embeddings, texts_list, list(metadatas))
self._collection.update(
ids=[document_id],
embeddings=[embeddings[0]],
documents=[text],
metadatas=[metadata],
)
def update(self, documents: List[Document]) -> List[str]:
if self._doc_manager:
has_id = [self._doc_manager.contains_doc(doc) for doc in documents]
old_ids = [
self._doc_manager.get_doc_id(d) for h, d in zip(has_id, documents) if h
]
new_ids = self.add_documents(
[d for h, d in zip(has_id, documents) if not h]
)
all_ids = []
old_idx, new_idx = 0, 0
for i, is_old in enumerate(has_id):
if is_old:
all_ids.append(old_ids[old_idx])
old_idx += 1
else:
all_ids.append(new_ids[new_idx])
new_idx += 1
return all_ids
else:
return self.add_documents(documents)
@classmethod
def from_texts(