mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
Compare commits
5 Commits
langchain-
...
dev2049/ro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bda5b0c050 | ||
|
|
37dd0304de | ||
|
|
562ee92e70 | ||
|
|
7912e13522 | ||
|
|
b6779b91ea |
@@ -1,12 +1,28 @@
|
|||||||
"""Interface to access to place that stores documents."""
|
"""Interface to access to place that stores documents."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Union
|
from pathlib import Path
|
||||||
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
# Type variable for the unique-id type that a DocManager is parameterised on.
UID_TYPE = TypeVar("UID_TYPE")
|
||||||
|
|
||||||
|
|
||||||
class Docstore(ABC):
|
class Docstore(ABC):
|
||||||
"""Interface to access to place that stores documents."""
|
"""Interface to place that stores documents."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(self, search: str) -> Union[str, Document]:
|
def search(self, search: str) -> Union[str, Document]:
|
||||||
@@ -23,3 +39,135 @@ class AddableMixin(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add(self, texts: Dict[str, Document]) -> None:
|
def add(self, texts: Dict[str, Document]) -> None:
|
||||||
"""Add more documents."""
|
"""Add more documents."""
|
||||||
|
|
||||||
|
|
||||||
|
class DocManager(ABC, Generic[UID_TYPE]):
    """Abstract registry that assigns a unique id to every stored document.

    The public methods (``add``, ``delete``, ``generate_uid`` and the bulk
    helpers) perform the shared bookkeeping checks. Subclasses supply the
    storage by implementing ``_add``, ``_delete``, ``contains_doc``,
    ``get_doc_id``, ``get`` and ``_generate_uid``.
    """

    def add(self, doc: Document) -> UID_TYPE:
        """Register ``doc`` under a freshly generated uid and return that uid.

        Raises:
            ValueError: If ``contains_doc(doc)`` is already true, or if the
                generated uid is already in use.
        """
        if self.contains_doc(doc):
            raise ValueError("Document is already registered with this manager.")
        uid = self.generate_uid()
        self._add(doc, uid)
        return uid

    def lazy_add_docs(self, docs: Iterable[Document]) -> Iterator[UID_TYPE]:
        """Add documents one at a time, yielding each uid as it is assigned."""
        for doc in docs:
            yield self.add(doc)

    def add_docs(self, docs: Iterable[Document]) -> List[UID_TYPE]:
        """Add every document in ``docs`` and return the uids in input order."""
        return list(self.lazy_add_docs(docs))

    def add_text(self, text: str, metadata: Optional[dict] = None) -> UID_TYPE:
        """Wrap ``text`` (and optional ``metadata``) in a Document and add it."""
        _metadata = metadata or {}
        return self.add(Document(page_content=text, metadata=_metadata))

    def lazy_add_texts(
        self, texts: Iterable[str], metadatas: Optional[Iterable[dict]] = None
    ) -> Iterator[UID_TYPE]:
        """Add texts one at a time, yielding each uid as it is assigned.

        When ``metadatas`` is missing or empty every text gets empty metadata.
        Otherwise texts and metadatas are paired positionally with ``zip``,
        which stops at the shorter of the two.
        """
        if not metadatas:
            # Iterate ``texts`` exactly once: deriving the default metadata
            # from a second pass over ``texts`` would skip every other item
            # when ``texts`` is a one-shot iterator.
            for text in texts:
                yield self.add_text(text)
            return
        for text, metadata in zip(texts, metadatas):
            yield self.add_text(text, metadata=metadata)

    def add_texts(
        self, texts: Iterable[str], metadatas: Optional[Iterable[dict]] = None
    ) -> List[UID_TYPE]:
        """Add every text in ``texts`` and return the uids in input order."""
        return list(self.lazy_add_texts(texts, metadatas=metadatas))

    @abstractmethod
    def _add(self, doc: Document, uid: UID_TYPE) -> None:
        """Store ``doc`` under ``uid``; called by ``add`` after its checks."""

    def delete(self, uid: UID_TYPE) -> None:
        """Remove the document stored under ``uid``.

        Raises:
            ValueError: If no document is stored under ``uid``.
        """
        if not self._uid_in_use(uid):
            raise ValueError(f"No document is registered under uid {uid!r}.")

        return self._delete(uid)

    @abstractmethod
    def _delete(self, uid: UID_TYPE) -> None:
        """Drop the entry for ``uid``; called by ``delete`` after its check."""

    @abstractmethod
    def contains_doc(self, doc: Document) -> bool:
        """Return whether ``doc`` is already registered."""

    @abstractmethod
    def get_doc_id(self, doc: Document) -> UID_TYPE:
        """Return the uid under which ``doc`` is registered."""

    @abstractmethod
    def get(self, uid: UID_TYPE) -> Document:
        """Return the document stored under ``uid``.

        ``delete`` and ``generate_uid`` use this to probe for existence: an
        unknown uid may be signalled either by a falsy return value or by
        raising ``KeyError``.
        """

    def generate_uid(self) -> UID_TYPE:
        """Return a new uid from ``_generate_uid`` that is not yet in use.

        This is a concrete template method, so it is deliberately not marked
        abstract: subclasses only need to implement ``_generate_uid``.

        Raises:
            ValueError: If the generated uid already maps to a document.
        """
        uid = self._generate_uid()
        if self._uid_in_use(uid):
            raise ValueError(f"Generated uid {uid!r} is already in use.")
        return uid

    @abstractmethod
    def _generate_uid(self) -> UID_TYPE:
        """Produce a candidate uid; uniqueness is checked by ``generate_uid``."""

    def _uid_in_use(self, uid: UID_TYPE) -> bool:
        """Return whether ``get(uid)`` finds a document.

        A ``KeyError`` from ``get`` is treated the same as a falsy result so
        that mapping-backed implementations work without special casing.
        """
        try:
            return bool(self.get(uid))
        except KeyError:
            return False

    def persist(self, path: Union[str, Path]) -> None:
        """Write the manager's state to ``path``; not implemented by default."""
        raise NotImplementedError

    @classmethod
    def load(cls, path: Union[str, Path]) -> DocManager:
        """Rebuild a manager from ``path``; not implemented by default."""
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleDocManager(DocManager[str]):
    """In-memory ``DocManager`` that keys documents by UUID4 strings.

    Two dictionaries are kept in sync: ``uid_doc_map`` (uid -> Document) and
    ``doc_uid_map`` (serialized Document -> uid). The serialized form is the
    JSON string produced by ``serialize``.
    """

    def __init__(
        self,
        uid_doc_map: Optional[Dict[str, Document]] = None,
        doc_uid_map: Optional[Dict[str, str]] = None,
    ):
        """Initialise from optional pre-built maps.

        Raises:
            ValueError: If the two maps do not have the same number of entries.
        """
        self.uid_doc_map = uid_doc_map or {}
        self.doc_uid_map = doc_uid_map or {}
        if len(self.uid_doc_map) != len(self.doc_uid_map):
            raise ValueError(
                "uid_doc_map and doc_uid_map must contain the same number of "
                f"entries, got {len(self.uid_doc_map)} and "
                f"{len(self.doc_uid_map)}."
            )

    def _add(self, doc: Document, uid: str) -> None:
        """Record ``doc`` under ``uid`` in both maps."""
        self.uid_doc_map[uid] = doc
        self.doc_uid_map[self.serialize(doc)] = uid

    def _delete(self, uid: str) -> None:
        """Remove ``uid`` and its document from both maps."""
        doc = self.uid_doc_map[uid]
        del self.doc_uid_map[self.serialize(doc)]
        del self.uid_doc_map[uid]

    def contains_doc(self, doc: Document) -> bool:
        """Return whether an identically-serialized document is registered."""
        return self.serialize(doc) in self.doc_uid_map

    def get_doc_id(self, doc: Document) -> str:
        """Return the uid of ``doc``; raises ``KeyError`` if it is unknown."""
        return self.doc_uid_map[self.serialize(doc)]

    def get(self, uid: str) -> Optional[Document]:  # type: ignore[override]
        """Return the document stored under ``uid``, or ``None`` if unknown.

        ``None`` (rather than ``KeyError``) is returned for a missing uid so
        that truthiness-based existence checks on the result work.
        """
        return self.uid_doc_map.get(uid)

    def serialize(self, doc: Document) -> str:
        """Return a canonical JSON string for ``doc`` (keys sorted)."""
        # Assumes metadata is JSON-serializable.
        return json.dumps(doc.dict(), sort_keys=True)

    def generate_uid(self) -> str:
        """Return a new uid that is not yet a key of ``uid_doc_map``.

        Raises:
            ValueError: If the generated uid is already in use.
        """
        uid = self._generate_uid()
        if uid in self.uid_doc_map:
            raise ValueError(f"Generated uid {uid!r} is already in use.")
        return uid

    def _generate_uid(self) -> str:
        """Return a random UUID4 rendered as a string."""
        return str(uuid.uuid4())

    def persist(self, path: Union[str, Path]) -> None:
        """Write a JSON object mapping each uid to its serialized document."""
        uid_doc_map = {
            uid: self.serialize(doc) for uid, doc in self.uid_doc_map.items()
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(uid_doc_map, f, indent=2)

    @classmethod
    def load(cls, path: Union[str, Path]) -> SimpleDocManager:
        """Rebuild a manager from a file written by ``persist``.

        Each value in the file is the JSON *string* produced by ``serialize``,
        so it has to be decoded before its fields can be read.
        """
        with open(path, "r", encoding="utf-8") as f:
            serialized_uid_doc_map = json.load(f)
        uid_doc_map: Dict[str, Document] = {}
        doc_uid_map: Dict[str, str] = {}
        for uid, serialized_doc in serialized_uid_doc_map.items():
            doc_dict = json.loads(serialized_doc)
            uid_doc_map[uid] = Document(
                page_content=doc_dict["page_content"], metadata=doc_dict["metadata"]
            )
            # The persisted string is exactly what ``serialize`` emitted, so it
            # doubles as the reverse-lookup key.
            doc_uid_map[serialized_doc] = uid
        return cls(uid_doc_map=uid_doc_map, doc_uid_map=doc_uid_map)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Ty
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from langchain.docstore.base import DocManager
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.utils import xor_args
|
from langchain.utils import xor_args
|
||||||
@@ -63,6 +64,7 @@ class Chroma(VectorStore):
|
|||||||
client_settings: Optional[chromadb.config.Settings] = None,
|
client_settings: Optional[chromadb.config.Settings] = None,
|
||||||
collection_metadata: Optional[Dict] = None,
|
collection_metadata: Optional[Dict] = None,
|
||||||
client: Optional[chromadb.Client] = None,
|
client: Optional[chromadb.Client] = None,
|
||||||
|
doc_manager: Optional[DocManager] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize with Chroma client."""
|
"""Initialize with Chroma client."""
|
||||||
try:
|
try:
|
||||||
@@ -97,6 +99,7 @@ class Chroma(VectorStore):
|
|||||||
else None,
|
else None,
|
||||||
metadata=collection_metadata,
|
metadata=collection_metadata,
|
||||||
)
|
)
|
||||||
|
self._doc_manager = doc_manager
|
||||||
|
|
||||||
@xor_args(("query_texts", "query_embeddings"))
|
@xor_args(("query_texts", "query_embeddings"))
|
||||||
def __query_collection(
|
def __query_collection(
|
||||||
@@ -152,15 +155,21 @@ class Chroma(VectorStore):
|
|||||||
List[str]: List of IDs of the added texts.
|
List[str]: List of IDs of the added texts.
|
||||||
"""
|
"""
|
||||||
# TODO: Handle the case where the user doesn't provide ids on the Collection
|
# TODO: Handle the case where the user doesn't provide ids on the Collection
|
||||||
if ids is None:
|
if self._doc_manager and ids:
|
||||||
ids = [str(uuid.uuid1()) for _ in texts]
|
raise ValueError
|
||||||
|
elif self._doc_manager:
|
||||||
|
_ids = self._doc_manager.add_texts(texts, metadatas=metadatas)
|
||||||
|
elif ids is None:
|
||||||
|
_ids = [str(uuid.uuid1()) for _ in texts]
|
||||||
|
else:
|
||||||
|
_ids = ids
|
||||||
embeddings = None
|
embeddings = None
|
||||||
if self._embedding_function is not None:
|
if self._embedding_function is not None:
|
||||||
embeddings = self._embedding_function.embed_documents(list(texts))
|
embeddings = self._embedding_function.embed_documents(list(texts))
|
||||||
self._collection.add(
|
self._collection.add(
|
||||||
metadatas=metadatas, embeddings=embeddings, documents=texts, ids=ids
|
metadatas=metadatas, embeddings=embeddings, documents=texts, ids=_ids
|
||||||
)
|
)
|
||||||
return ids
|
return _ids
|
||||||
|
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
self,
|
self,
|
||||||
@@ -346,20 +355,42 @@ class Chroma(VectorStore):
|
|||||||
document_id (str): ID of the document to update.
|
document_id (str): ID of the document to update.
|
||||||
document (Document): Document to update.
|
document (Document): Document to update.
|
||||||
"""
|
"""
|
||||||
text = document.page_content
|
return self.update_existing_documents([document_id], [document])
|
||||||
metadata = document.metadata
|
|
||||||
|
def update_existing_documents(
|
||||||
|
self, document_ids: List[str], documents: List[Document]
|
||||||
|
) -> None:
|
||||||
|
"""Update a document in the collection."""
|
||||||
|
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
|
||||||
|
texts_list = list(texts)
|
||||||
if self._embedding_function is None:
|
if self._embedding_function is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"For update, you must specify an embedding function on creation."
|
"For update, you must specify an embedding function on creation."
|
||||||
)
|
)
|
||||||
embeddings = self._embedding_function.embed_documents(list(text))
|
embeddings = self._embedding_function.embed_documents(texts_list)
|
||||||
|
self._collection.update(document_ids, embeddings, texts_list, list(metadatas))
|
||||||
|
|
||||||
self._collection.update(
|
def update(self, documents: List[Document]) -> List[str]:
|
||||||
ids=[document_id],
|
if self._doc_manager:
|
||||||
embeddings=[embeddings[0]],
|
has_id = [self._doc_manager.contains_doc(doc) for doc in documents]
|
||||||
documents=[text],
|
old_ids = [
|
||||||
metadatas=[metadata],
|
self._doc_manager.get_doc_id(d) for h, d in zip(has_id, documents) if h
|
||||||
)
|
]
|
||||||
|
new_ids = self.add_documents(
|
||||||
|
[d for h, d in zip(has_id, documents) if not h]
|
||||||
|
)
|
||||||
|
all_ids = []
|
||||||
|
old_idx, new_idx = 0, 0
|
||||||
|
for i, is_old in enumerate(has_id):
|
||||||
|
if is_old:
|
||||||
|
all_ids.append(old_ids[old_idx])
|
||||||
|
old_idx += 1
|
||||||
|
else:
|
||||||
|
all_ids.append(new_ids[new_idx])
|
||||||
|
new_idx += 1
|
||||||
|
return all_ids
|
||||||
|
else:
|
||||||
|
return self.add_documents(documents)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
|
|||||||
Reference in New Issue
Block a user