mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
rfc
This commit is contained in:
@@ -1,12 +1,29 @@
|
||||
"""Interface to access to place that stores documents."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Union
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
UID_TYPE = TypeVar("UID_TYPE")
|
||||
|
||||
|
||||
class Docstore(ABC):
|
||||
"""Interface to access to place that stores documents."""
|
||||
"""Interface to place that stores documents."""
|
||||
|
||||
@abstractmethod
|
||||
def search(self, search: str) -> Union[str, Document]:
|
||||
@@ -17,9 +34,141 @@ class Docstore(ABC):
|
||||
"""
|
||||
|
||||
|
||||
class AddableMixin(ABC):
    """Mixin class that supports adding texts.

    Note: a prior revision declared ``class AddableMixin(object, ABC)``;
    that base order raises ``TypeError`` (no consistent MRO, since ``ABC``
    already derives from ``object``), so ``ABC`` alone is used.
    """

    @abstractmethod
    def add(self, texts: Dict[str, Document]) -> None:
        """Add more documents, given as a mapping of string keys to documents.

        # NOTE(review): keys are presumably document ids — confirm with callers.
        """
|
||||
|
||||
|
||||
class DocManager(ABC, Generic[UID_TYPE]):
    """Base class for assigning unique ids (of type ``UID_TYPE``) to documents.

    Subclasses implement the primitive hooks (``_add``, ``_delete``,
    ``contains_doc``, ``get_doc_id``, ``get``, ``_generate_uid``); the public
    ``add*``/``delete``/``generate_uid`` methods are built on top of them.
    """

    def add(self, doc: Document) -> UID_TYPE:
        """Add a single document and return its newly assigned uid.

        Raises:
            ValueError: If the document is already present.
        """
        if self.contains_doc(doc):
            raise ValueError
        uid = self.generate_uid()
        self._add(doc, uid)
        return uid

    def lazy_add_docs(self, docs: Iterator[Document]) -> Iterator[UID_TYPE]:
        """Lazily add documents, yielding each assigned uid."""
        for doc in docs:
            yield self.add(doc)

    def add_docs(self, docs: Sequence[Document]) -> List[UID_TYPE]:
        """Add documents and return the list of assigned uids."""
        return list(self.lazy_add_docs(docs))

    def add_text(self, text: str, metadata: Optional[dict] = None) -> UID_TYPE:
        """Wrap ``text`` (plus optional metadata) in a Document and add it."""
        _metadata = metadata or {}
        return self.add(Document(page_content=text, metadata=_metadata))

    def lazy_add_texts(
        self, texts: Iterator[str], metadatas: Optional[Iterator[dict]] = None
    ) -> Iterator[UID_TYPE]:
        """Lazily add texts, yielding each assigned uid.

        When ``metadatas`` is falsy, each text gets its own fresh empty dict.
        """
        if not metadatas:
            # The previous implementation zipped ``texts`` against
            # ``({} for _ in texts)`` — a generator iterating the SAME
            # iterator — which silently dropped every other element when
            # ``texts`` was a one-shot iterator.  Iterate once instead.
            for text in texts:
                yield self.add_text(text, metadata={})
        else:
            for text, metadata in zip(texts, metadatas):
                yield self.add_text(text, metadata=metadata)

    def add_texts(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[UID_TYPE]:
        """Add texts and return the list of assigned uids."""
        return list(self.lazy_add_texts(texts, metadatas=metadatas))

    @abstractmethod
    def _add(self, doc: Document, uid: UID_TYPE) -> None:
        """Store ``doc`` under ``uid`` (primitive; no duplicate checking)."""

    def delete(self, uid: UID_TYPE) -> None:
        """Delete the document stored under ``uid``.

        Raises:
            ValueError: If no document is stored under ``uid``.
        """
        if not self.get(uid):
            raise ValueError

        return self._delete(uid)

    @abstractmethod
    def _delete(self, uid: UID_TYPE) -> None:
        """Remove the document stored under ``uid`` (primitive)."""

    @abstractmethod
    def contains_doc(self, doc: Document) -> bool:
        """Return True if ``doc`` is already stored."""

    @abstractmethod
    def get_doc_id(self, doc: Document) -> UID_TYPE:
        """Return the uid under which ``doc`` is stored."""

    @abstractmethod
    def get(self, uid: UID_TYPE) -> Document:
        """Return the document stored under ``uid``.

        Implementations should return a falsy value for an unknown uid:
        both ``delete`` and ``generate_uid`` rely on that convention.
        """

    def generate_uid(self) -> UID_TYPE:
        """Generate a uid that is not already in use.

        Note: this method previously carried ``@abstractmethod`` despite
        having a concrete implementation, which forced every subclass to
        override it (and made ``SimpleDocManager`` uninstantiable).  Only
        ``_generate_uid`` is the abstract hook.

        Raises:
            ValueError: If the freshly generated uid is already in use.
        """
        uid = self._generate_uid()
        if self.get(uid):
            raise ValueError
        return uid

    @abstractmethod
    def _generate_uid(self) -> UID_TYPE:
        """Produce a candidate uid (primitive; uniqueness checked by caller)."""

    def persist(self, path: Union[str, Path]) -> None:
        """Write this manager's contents to ``path`` (optional capability)."""
        raise NotImplementedError

    @classmethod
    def load(cls, path: Union[str, Path]) -> DocManager:
        """Load a manager previously written with ``persist`` (optional)."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class SimpleDocManager(DocManager[str]):
    """In-memory DocManager keyed by uuid4 strings.

    Keeps two synchronized maps: ``uid -> Document`` and
    ``serialized document -> uid``.
    """

    def __init__(
        self,
        uid_doc_map: Optional[Dict[str, Document]] = None,
        doc_uid_map: Optional[Dict[str, str]] = None,
    ):
        """Initialize, optionally from existing (mutually consistent) maps.

        Raises:
            ValueError: If the two maps disagree in size.
        """
        self.uid_doc_map = uid_doc_map or {}
        self.doc_uid_map = doc_uid_map or {}
        if len(self.uid_doc_map) != len(self.doc_uid_map):
            raise ValueError

    def _add(self, doc: Document, uid: str) -> None:
        """Record ``doc`` under ``uid`` in both maps."""
        self.uid_doc_map[uid] = doc
        self.doc_uid_map[self.serialize(doc)] = uid

    def _delete(self, uid: str) -> None:
        """Drop the document stored under ``uid`` from both maps."""
        doc = self.uid_doc_map[uid]
        del self.doc_uid_map[self.serialize(doc)]
        del self.uid_doc_map[uid]

    def contains_doc(self, doc: Document) -> bool:
        """Return True if an identical (serialized) document is stored."""
        return self.serialize(doc) in self.doc_uid_map

    def get_doc_id(self, doc: Document) -> str:
        """Return the uid under which ``doc`` is stored."""
        return self.doc_uid_map[self.serialize(doc)]

    def get(self, uid: str) -> Optional[Document]:
        """Return the document stored under ``uid``, or None if absent.

        Note: the base class declares ``get`` abstract; it was previously
        missing here, leaving this class abstract and uninstantiable.
        """
        return self.uid_doc_map.get(uid)

    def serialize(self, doc: Document) -> str:
        """Canonically serialize ``doc`` for use as a dict key.

        Assumes metadata is JSON-serializable.
        """
        return json.dumps(doc.dict(), sort_keys=True)

    def _generate_uid(self) -> str:
        """Produce a random uuid4 string."""
        return str(uuid.uuid4())

    def persist(self, path: Union[str, Path]) -> None:
        """Write ``uid -> serialized document`` pairs to ``path`` as JSON."""
        uid_doc_map = {
            uid: self.serialize(doc) for uid, doc in self.uid_doc_map.items()
        }
        with open(path, "w") as f:
            json.dump(uid_doc_map, f, indent=2)

    @classmethod
    def load(cls, path: Union[str, Path]) -> SimpleDocManager:
        """Load a manager previously written with ``persist``.

        Note: ``persist`` stores each document as a JSON *string*; the
        previous implementation indexed those strings as if they were dicts
        (``doc_dict["page_content"]`` on a ``str`` raises ``TypeError``).
        They are decoded with ``json.loads`` first.
        """
        with open(path, "r") as f:
            serialized_uid_doc_map = json.load(f)
        uid_doc_map = {}
        for uid, serialized in serialized_uid_doc_map.items():
            doc_dict = json.loads(serialized)
            uid_doc_map[uid] = Document(
                page_content=doc_dict["page_content"], metadata=doc_dict["metadata"]
            )
        doc_uid_map = {v: k for k, v in serialized_uid_doc_map.items()}
        return cls(uid_doc_map=uid_doc_map, doc_uid_map=doc_uid_map)
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Ty
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.base import DocManager
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import xor_args
|
||||
@@ -63,6 +64,7 @@ class Chroma(VectorStore):
|
||||
client_settings: Optional[chromadb.config.Settings] = None,
|
||||
collection_metadata: Optional[Dict] = None,
|
||||
client: Optional[chromadb.Client] = None,
|
||||
doc_manager: Optional[DocManager] = None,
|
||||
) -> None:
|
||||
"""Initialize with Chroma client."""
|
||||
try:
|
||||
@@ -97,6 +99,7 @@ class Chroma(VectorStore):
|
||||
else None,
|
||||
metadata=collection_metadata,
|
||||
)
|
||||
self._doc_manager = doc_manager
|
||||
|
||||
@xor_args(("query_texts", "query_embeddings"))
|
||||
def __query_collection(
|
||||
@@ -152,6 +155,10 @@ class Chroma(VectorStore):
|
||||
List[str]: List of IDs of the added texts.
|
||||
"""
|
||||
# TODO: Handle the case where the user doesn't provide ids on the Collection
|
||||
if self._doc_manager and ids:
|
||||
raise ValueError
|
||||
elif not ids:
|
||||
ids = self._doc_manager.add_texts(texts, metadatas=metadatas)
|
||||
if ids is None:
|
||||
ids = [str(uuid.uuid1()) for _ in texts]
|
||||
embeddings = None
|
||||
@@ -346,20 +353,40 @@ class Chroma(VectorStore):
|
||||
document_id (str): ID of the document to update.
|
||||
document (Document): Document to update.
|
||||
"""
|
||||
text = document.page_content
|
||||
metadata = document.metadata
|
||||
return self.update_existing_documents([document_id], [document])
|
||||
|
||||
def update_existing_documents(
    self, document_ids: List[str], documents: List[Document]
) -> None:
    """Re-embed and update already-stored documents in the collection.

    Args:
        document_ids: IDs of the documents to update, parallel to ``documents``.
        documents: New contents (text + metadata) for those ids.

    Raises:
        ValueError: If no embedding function was supplied at creation time.

    Note: the previous revision contained two conflicting merge leftovers —
    ``embed_documents(list(text))`` (undefined name ``text``) and a stale
    second ``self._collection.update`` call referencing single-document
    variables.  Both are removed here.
    """
    if not documents:
        # Nothing to update; also avoids the unpack error zip(*()) would raise.
        return
    texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
    if self._embedding_function is None:
        raise ValueError(
            "For update, you must specify an embedding function on creation."
        )
    embeddings = self._embedding_function.embed_documents(list(texts))
    self._collection.update(document_ids, embeddings, list(texts), list(metadatas))
|
||||
def update(self, documents: List[Document]) -> List[str]:
    """Upsert documents: reuse ids for known documents, add the rest.

    Args:
        documents: Documents to upsert.

    Returns:
        The id of every document, in input order.

    Note: the previous revision never returned ``all_ids`` from the
    doc-manager branch, so callers received ``None`` despite the
    ``List[str]`` annotation.
    """
    if not self._doc_manager:
        return self.add_documents(documents)
    has_id = [self._doc_manager.contains_doc(doc) for doc in documents]
    old_ids = [
        self._doc_manager.get_doc_id(d) for h, d in zip(has_id, documents) if h
    ]
    new_ids = self.add_documents([d for h, d in zip(has_id, documents) if not h])
    # Interleave reused and freshly assigned ids back into input order.
    old_it, new_it = iter(old_ids), iter(new_ids)
    all_ids = [next(old_it) if is_old else next(new_it) for is_old in has_id]
    return all_ids
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
|
||||
Reference in New Issue
Block a user