Compare commits

..

1 Commits

Author SHA1 Message Date
vowelparrot
774c405707 Another crazy 2023-06-09 12:48:14 -07:00
94 changed files with 1222 additions and 253 deletions

View File

@@ -29,6 +29,7 @@ class BaseMetadataCallbackHandler:
ignore_llm_ (bool): Whether to ignore llm callbacks.
ignore_chain_ (bool): Whether to ignore chain callbacks.
ignore_agent_ (bool): Whether to ignore agent callbacks.
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
always_verbose_ (bool): Whether to always be verbose.
chain_starts (int): The number of times the chain start method has been called.
chain_ends (int): The number of times the chain end method has been called.
@@ -51,6 +52,7 @@ class BaseMetadataCallbackHandler:
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.ignore_retriever_ = False
self.always_verbose_ = False
self.chain_starts = 0
@@ -85,6 +87,11 @@ class BaseMetadataCallbackHandler:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
@property
def ignore_retriever(self) -> bool:
    """Whether to ignore retriever callbacks."""
    # Backed by the ignore_retriever_ attribute set in __init__.
    return self.ignore_retriever_
def get_custom_callback_meta(self) -> Dict[str, Any]:
return {
"step": self.step,

View File

@@ -1,15 +1,45 @@
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID
from langchain.schema.base import (
AgentAction,
AgentFinish,
BaseMessage,
LLMResult,
)
from langchain.schema.base import AgentAction, AgentFinish, BaseMessage, LLMResult
from langchain.schema.document import Document
class RetrieverManagerMixin:
    """Mixin for Retriever callbacks.

    Provides the no-op hooks a callback handler may override to observe a
    retriever run's lifecycle: start, error, and end. Each hook receives the
    run's ``run_id`` and, when the run is nested, its ``parent_run_id``.
    """

    def on_retriever_start(
        self,
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever starts running."""

    def on_retriever_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever errors."""

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever ends running."""
class LLMManagerMixin:
@@ -183,6 +213,7 @@ class BaseCallbackHandler(
LLMManagerMixin,
ChainManagerMixin,
ToolManagerMixin,
RetrieverManagerMixin,
CallbackManagerMixin,
RunManagerMixin,
):
@@ -361,6 +392,36 @@ class AsyncCallbackHandler(BaseCallbackHandler):
) -> None:
"""Run on agent end."""
async def on_retriever_start(
self,
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on retriever start."""
async def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on retriever end."""
async def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on retriever error."""
class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager that can be used to handle callbacks from LangChain."""

View File

@@ -14,6 +14,7 @@ from typing import (
Generator,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
@@ -27,6 +28,7 @@ from langchain.callbacks.base import (
BaseCallbackManager,
ChainManagerMixin,
LLMManagerMixin,
RetrieverManagerMixin,
RunManagerMixin,
ToolManagerMixin,
)
@@ -43,6 +45,7 @@ from langchain.schema.base import (
LLMResult,
get_buffer_string,
)
from langchain.schema.document import Document
logger = logging.getLogger(__name__)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
@@ -624,6 +627,91 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
)
class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin):
    """Callback manager for retriever run."""

    def get_child(self) -> CallbackManager:
        """Get a child callback manager.

        The child inherits this run's inheritable handlers and is parented
        to this run's ``run_id`` so nested runs nest under the retriever run.
        """
        manager = CallbackManager([], parent_run_id=self.run_id)
        manager.set_handlers(self.inheritable_handlers)
        return manager

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        **kwargs: Any,
    ) -> None:
        """Run when retriever ends running."""
        # Dispatch to every handler unless it opts out via ignore_retriever.
        _handle_event(
            self.handlers,
            "on_retriever_end",
            "ignore_retriever",
            documents,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    def on_retriever_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when retriever errors."""
        _handle_event(
            self.handlers,
            "on_retriever_error",
            "ignore_retriever",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )
class AsyncCallbackManagerForRetrieverRun(
    AsyncRunManager,
    RetrieverManagerMixin,
):
    """Async callback manager for retriever run."""

    def get_child(self) -> AsyncCallbackManager:
        """Get a child callback manager.

        The child inherits this run's inheritable handlers and is parented
        to this run's ``run_id``.
        """
        manager = AsyncCallbackManager([], parent_run_id=self.run_id)
        manager.set_handlers(self.inheritable_handlers)
        return manager

    async def on_retriever_end(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> None:
        """Run when retriever ends running."""
        # Async dispatch; handlers with ignore_retriever set are skipped.
        await _ahandle_event(
            self.handlers,
            "on_retriever_end",
            "ignore_retriever",
            documents,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    async def on_retriever_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when retriever errors."""
        await _ahandle_event(
            self.handlers,
            "on_retriever_error",
            "ignore_retriever",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
@@ -733,6 +821,29 @@ class CallbackManager(BaseCallbackManager):
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
def on_retriever_start(
    self,
    query: str,
    run_id: Optional[UUID] = None,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> CallbackManagerForRetrieverRun:
    """Run when retriever starts running.

    Args:
        query: The query being sent to the retriever.
        run_id: Optional run id; a fresh UUID is generated when omitted.
        parent_run_id: Accepted for interface symmetry with handlers, but the
            manager's own ``self.parent_run_id`` is what gets propagated
            (matching the other ``on_*_start`` methods in this class).
        **kwargs: Extra arguments forwarded to every handler.

    Returns:
        A CallbackManagerForRetrieverRun scoped to this retriever run.
    """
    if run_id is None:
        run_id = uuid4()
    # Forward **kwargs so handler-level extras are not silently dropped
    # (the other on_*_start manager methods forward them the same way).
    _handle_event(
        self.handlers,
        "on_retriever_start",
        "ignore_retriever",
        query,
        run_id=run_id,
        parent_run_id=self.parent_run_id,
        **kwargs,
    )
    return CallbackManagerForRetrieverRun(
        run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
    )
@classmethod
def configure(
cls,
@@ -856,6 +967,29 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
async def on_retriever_start(
    self,
    query: str,
    run_id: Optional[UUID] = None,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> AsyncCallbackManagerForRetrieverRun:
    """Run when retriever starts running.

    Args:
        query: The query being sent to the retriever.
        run_id: Optional run id; a fresh UUID is generated when omitted.
        parent_run_id: Accepted for interface symmetry with handlers, but the
            manager's own ``self.parent_run_id`` is what gets propagated.
        **kwargs: Extra arguments forwarded to every handler.

    Returns:
        An AsyncCallbackManagerForRetrieverRun scoped to this retriever run.
    """
    if run_id is None:
        run_id = uuid4()
    # Forward **kwargs so handler-level extras are not silently dropped.
    await _ahandle_event(
        self.handlers,
        "on_retriever_start",
        "ignore_retriever",
        query,
        run_id=run_id,
        parent_run_id=self.parent_run_id,
        **kwargs,
    )
    return AsyncCallbackManagerForRetrieverRun(
        run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
    )
@classmethod
def configure(
cls,

View File

@@ -3,12 +3,13 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.schema.base import LLMResult
from langchain.schema.document import Document
class TracerException(Exception):
@@ -262,6 +263,65 @@ class BaseTracer(BaseCallbackHandler, ABC):
self._end_trace(tool_run)
self._on_tool_error(tool_run)
def on_retriever_start(
    self,
    query: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Run when Retriever starts running."""
    # Execution order is computed relative to the parent run (if any) so
    # sibling runs are ordered correctly within the trace tree.
    parent_run_id_ = str(parent_run_id) if parent_run_id else None
    execution_order = self._get_execution_order(parent_run_id_)
    retrieval_run = Run(
        id=run_id,
        name="Retriever",
        parent_run_id=parent_run_id,
        inputs={"query": query},
        extra=kwargs,  # unrecognized kwargs are preserved on the run record
        start_time=datetime.utcnow(),
        execution_order=execution_order,
        child_execution_order=execution_order,
        child_runs=[],
        run_type=RunTypeEnum.retriever,
    )
    self._start_trace(retrieval_run)
    self._on_retriever_start(retrieval_run)
def on_retriever_error(
    self,
    error: Union[Exception, KeyboardInterrupt],
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Run when Retriever errors."""
    if not run_id:
        raise TracerException("No run_id provided for on_retriever_error callback.")
    # The run must have been registered by on_retriever_start and be a
    # retriever-type run; anything else indicates a mismatched callback.
    retrieval_run = self.run_map.get(str(run_id))
    if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
        raise TracerException("No retriever Run found to be traced")
    # Record the error as its repr so the exception type is visible in traces.
    retrieval_run.error = repr(error)
    retrieval_run.end_time = datetime.utcnow()
    self._end_trace(retrieval_run)
    self._on_retriever_error(retrieval_run)
def on_retriever_end(
    self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> None:
    """Run when Retriever ends running."""
    if not run_id:
        raise TracerException("No run_id provided for on_retriever_end callback.")
    # The run must have been registered by on_retriever_start and be a
    # retriever-type run; anything else indicates a mismatched callback.
    retrieval_run = self.run_map.get(str(run_id))
    if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
        raise TracerException("No retriever Run found to be traced")
    retrieval_run.outputs = {"documents": documents}
    retrieval_run.end_time = datetime.utcnow()
    self._end_trace(retrieval_run)
    self._on_retriever_end(retrieval_run)
def __deepcopy__(self, memo: dict) -> BaseTracer:
    """Deepcopy the tracer.

    Deliberately returns the same instance instead of a copy, so shared
    tracer state (e.g. the run map) stays consistent when containing
    objects are deep-copied.
    """
    return self
@@ -299,3 +359,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
def _on_chat_model_start(self, run: Run) -> None:
"""Process the Chat Model Run upon start."""
def _on_retriever_start(self, run: Run) -> None:
    """Process the Retriever Run upon start. No-op hook for subclasses."""

def _on_retriever_end(self, run: Run) -> None:
    """Process the Retriever Run. No-op hook for subclasses."""

def _on_retriever_error(self, run: Run) -> None:
    """Process the Retriever Run upon error. No-op hook for subclasses."""

View File

@@ -122,3 +122,15 @@ class LangChainTracer(BaseTracer):
def _on_tool_error(self, run: Run) -> None:
"""Process the Tool Run upon error."""
self.executor.submit(self._update_run_single, run.copy(deep=True))
def _on_retriever_start(self, run: Run) -> None:
    """Process the Retriever Run upon start."""
    # Submit a deep copy to the executor so persisting never blocks the run
    # and later mutations of the Run object do not race the upload.
    self.executor.submit(self._persist_run_single, run.copy(deep=True))

def _on_retriever_end(self, run: Run) -> None:
    """Process the Retriever Run."""
    self.executor.submit(self._update_run_single, run.copy(deep=True))

def _on_retriever_error(self, run: Run) -> None:
    """Process the Retriever Run upon error."""
    self.executor.submit(self._update_run_single, run.copy(deep=True))

View File

@@ -94,6 +94,7 @@ class RunTypeEnum(str, Enum):
tool = "tool"
chain = "chain"
llm = "llm"
retriever = "retriever"
class RunBase(BaseModel):

View File

@@ -21,7 +21,9 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema.base import BaseMessage, BaseRetriever, Document
from langchain.schema.base import BaseMessage
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.base import VectorStore
# Depending on the memory type and configuration, the chat history format may differ.
@@ -87,7 +89,13 @@ class BaseConversationalRetrievalChain(Chain):
return _output_keys
@abstractmethod
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[Document]:
"""Get docs."""
def _call(
@@ -107,7 +115,7 @@ class BaseConversationalRetrievalChain(Chain):
)
else:
new_question = question
docs = self._get_docs(new_question, inputs)
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
@@ -122,7 +130,13 @@ class BaseConversationalRetrievalChain(Chain):
return output
@abstractmethod
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
"""Get docs."""
async def _acall(
@@ -141,7 +155,7 @@ class BaseConversationalRetrievalChain(Chain):
)
else:
new_question = question
docs = await self._aget_docs(new_question, inputs)
docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
@@ -187,12 +201,28 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
return docs[:num_docs]
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
docs = self.retriever.get_relevant_documents(question)
def _get_docs(
    self,
    question: str,
    inputs: Dict[str, Any],
    *,
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[Document]:
    """Fetch documents for *question* via the retriever, capped to the token limit.

    Falls back to a no-op callback manager when none is supplied so
    retriever callbacks can always be wired through ``get_child()``.
    """
    run_manager_ = run_manager or CallbackManagerForChainRun.get_noop_manager()
    docs = self.retriever.retrieve(question, callbacks=run_manager_.get_child())
    return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
docs = await self.retriever.aget_relevant_documents(question)
async def _aget_docs(
    self,
    question: str,
    inputs: Dict[str, Any],
    *,
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
    """Asynchronously fetch documents for *question*, capped to the token limit.

    Falls back to a no-op callback manager when none is supplied.
    """
    run_manager_ = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
    # Use the callback-aware async entry point (``aretrieve``) to mirror the
    # sync path's ``retrieve``; the legacy ``aget_relevant_documents`` does
    # not accept a ``callbacks`` argument.
    docs = await self.retriever.aretrieve(
        question, callbacks=run_manager_.get_child()
    )
    return self._reduce_tokens_below_limit(docs)
@classmethod
@@ -253,14 +283,26 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
)
return values
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
    self,
    question: str,
    inputs: Dict[str, Any],
    *,
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[Document]:
    """Similarity-search the vector store for *question*.

    ``run_manager`` is accepted for interface parity with the base class but
    is not used here — the vector store call takes no callbacks.
    """
    # Caller-supplied kwargs in inputs override the chain's defaults.
    vectordbkwargs = inputs.get("vectordbkwargs", {})
    full_kwargs = {**self.search_kwargs, **vectordbkwargs}
    return self.vectorstore.similarity_search(
        question, k=self.top_k_docs_for_context, **full_kwargs
    )
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
    self,
    question: str,
    inputs: Dict[str, Any],
    *,
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
    """Async doc retrieval is intentionally unsupported for this chain."""
    raise NotImplementedError("ChatVectorDBChain does not support async")
@classmethod

View File

@@ -8,9 +8,7 @@ import numpy as np
from pydantic import Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
CallbackManagerForChainRun,
)
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
PROMPT,
@@ -20,7 +18,8 @@ from langchain.chains.flare.prompts import (
from langchain.chains.llm import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import BasePromptTemplate
from langchain.schema.base import BaseRetriever, Generation
from langchain.schema.base import Generation
from langchain.schema.retriever import BaseRetriever
class _ResponseChain(LLMChain):
@@ -124,7 +123,7 @@ class FlareChain(Chain):
callbacks = _run_manager.get_child()
docs = []
for question in questions:
docs.extend(self.retriever.get_relevant_documents(question))
docs.extend(self.retriever.retrieve(question, callbacks=callbacks))
context = "\n\n".join(d.page_content for d in docs)
result = self.response_chain.predict(
user_input=user_input,

View File

@@ -115,7 +115,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
return values
@abstractmethod
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[Document]:
"""Get docs to run questioning over."""
def _call(
@@ -124,7 +129,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = self._get_docs(inputs)
docs = self._get_docs(inputs, run_manager=_run_manager)
answer = self.combine_documents_chain.run(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
@@ -141,7 +146,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
return result
@abstractmethod
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
"""Get docs to run questioning over."""
async def _acall(
@@ -150,7 +160,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = await self._aget_docs(inputs)
docs = await self._aget_docs(inputs, run_manager=_run_manager)
answer = await self.combine_documents_chain.arun(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
@@ -180,10 +190,20 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
"""
return [self.input_docs_key, self.question_key]
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
    self,
    inputs: Dict[str, Any],
    *,
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[Document]:
    """Return the caller-provided documents, removing them from *inputs*.

    NOTE(review): ``pop`` mutates *inputs* so the docs are not re-passed to
    the combine step; ``run_manager`` is unused here.
    """
    return inputs.pop(self.input_docs_key)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
    self,
    inputs: Dict[str, Any],
    *,
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
    """Async variant of ``_get_docs``: pop and return the provided documents."""
    return inputs.pop(self.input_docs_key)
@property

View File

@@ -1,13 +1,17 @@
"""Question-answering with sources over an index."""
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
from langchain.schema.base import BaseRetriever
from langchain.schema.retriever import BaseRetriever
class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
@@ -40,12 +44,26 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs]
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
    self,
    inputs: Dict[str, Any],
    *,
    run_manager: Optional[CallbackManagerForChainRun] = None
) -> List[Document]:
    """Retrieve documents for the question in *inputs*, capped to the token limit.

    Falls back to a no-op callback manager when none is supplied so the
    retriever always receives a child callbacks object.
    """
    run_manager_ = run_manager or CallbackManagerForChainRun.get_noop_manager()
    question = inputs[self.question_key]
    docs = self.retriever.retrieve(question, callbacks=run_manager_.get_child())
    return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
    self,
    inputs: Dict[str, Any],
    *,
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None
) -> List[Document]:
    """Async variant: retrieve via ``aretrieve`` with child callbacks."""
    run_manager_ = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
    question = inputs[self.question_key]
    docs = await self.retriever.aretrieve(
        question, callbacks=run_manager_.get_child()
    )
    return self._reduce_tokens_below_limit(docs)

View File

@@ -1,10 +1,14 @@
"""Question-answering with sources over a vector database."""
import warnings
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from pydantic import Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
@@ -45,14 +49,24 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs]
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None
) -> List[Document]:
question = inputs[self.question_key]
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
)
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None
) -> List[Document]:
raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")
@root_validator()

View File

@@ -19,7 +19,8 @@ from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.prompts import PromptTemplate
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.base import VectorStore
@@ -94,7 +95,9 @@ class BaseRetrievalQA(Chain):
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
@abstractmethod
def _get_docs(self, question: str) -> List[Document]:
def _get_docs(
self, question: str, *, run_manager: CallbackManagerForChainRun
) -> List[Document]:
"""Get documents to do question answering over."""
def _call(
@@ -116,7 +119,7 @@ class BaseRetrievalQA(Chain):
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
docs = self._get_docs(question)
docs = self._get_docs(question, run_manager=_run_manager)
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
@@ -127,7 +130,9 @@ class BaseRetrievalQA(Chain):
return {self.output_key: answer}
@abstractmethod
async def _aget_docs(self, question: str) -> List[Document]:
async def _aget_docs(
self, question: str, *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]:
"""Get documents to do question answering over."""
async def _acall(
@@ -149,7 +154,7 @@ class BaseRetrievalQA(Chain):
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
docs = await self._aget_docs(question)
docs = await self._aget_docs(question, run_manager=_run_manager)
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
@@ -177,11 +182,22 @@ class RetrievalQA(BaseRetrievalQA):
retriever: BaseRetriever = Field(exclude=True)
def _get_docs(self, question: str) -> List[Document]:
return self.retriever.get_relevant_documents(question)
def _get_docs(
    self, question: str, *, run_manager: Optional[CallbackManagerForChainRun] = None
) -> List[Document]:
    """Get documents to do question answering over.

    Falls back to a no-op callback manager when none is supplied.
    """
    _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
    # Pass the child manager as ``callbacks`` — not ``run_manager`` — to
    # match every other retrieve()/aretrieve() call site (including this
    # class's async counterpart).
    return self.retriever.retrieve(question, callbacks=_run_manager.get_child())
async def _aget_docs(self, question: str) -> List[Document]:
return await self.retriever.aget_relevant_documents(question)
async def _aget_docs(
    self,
    question: str,
    *,
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> List[Document]:
    """Asynchronously get documents to do question answering over.

    Falls back to a no-op callback manager when none is supplied.
    """
    _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
    return await self.retriever.aretrieve(
        question, callbacks=_run_manager.get_child()
    )
@property
def _chain_type(self) -> str:
@@ -218,7 +234,9 @@ class VectorDBQA(BaseRetrievalQA):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
def _get_docs(self, question: str) -> List[Document]:
def _get_docs(
self, question: str, *, run_manager: CallbackManagerForChainRun
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
@@ -231,7 +249,9 @@ class VectorDBQA(BaseRetrievalQA):
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def _aget_docs(self, question: str) -> List[Document]:
async def _aget_docs(
self, question: str, *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]:
raise NotImplementedError("VectorDBQA does not support async")
@property

View File

@@ -15,7 +15,7 @@ from langchain.chains.router.multi_retrieval_prompt import (
)
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.schema.base import BaseRetriever
from langchain.schema.retriever import BaseRetriever
class MultiRetrievalQAChain(MultiRouteChain):

View File

@@ -1,7 +1,7 @@
from typing import Callable, Union
from langchain.docstore.base import Docstore
from langchain.schema.base import Document
from langchain.schema.document import Document
class DocstoreFn(Docstore):

View File

@@ -1,3 +1,3 @@
from langchain.schema.base import Document
from langchain.schema.document import Document
__all__ = ["Document"]

View File

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from typing import Iterator, List, Optional
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

View File

@@ -6,7 +6,7 @@ from typing import Iterator, List, Literal, Optional, Sequence, Union
from langchain.document_loaders.base import BaseBlobParser, BaseLoader
from langchain.document_loaders.blob_loaders import BlobLoader, FileSystemBlobLoader
from langchain.document_loaders.parsers.registry import get_parser
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.text_splitter import TextSplitter
_PathLike = Union[str, Path]

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from typing import Iterator, List, Optional
from langchain.document_loaders.base import BaseLoader
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.utils import get_from_env
LINK_NOTE_TEMPLATE = "joplin://x-callback-url/openNote?id={id}"

View File

@@ -2,7 +2,7 @@ from typing import Iterator
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document
class OpenAIWhisperParser(BaseBlobParser):

View File

@@ -6,7 +6,7 @@ from typing import Iterator, Mapping, Optional
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders.schema import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document
class MimeTypeBasedParser(BaseBlobParser):

View File

@@ -3,7 +3,7 @@ from typing import Any, Iterator, Mapping, Optional
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document
class PyPDFParser(BaseBlobParser):

View File

@@ -3,7 +3,7 @@ from typing import Iterator
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document
class TextParser(BaseBlobParser):

View File

@@ -4,7 +4,7 @@ import re
from typing import Any, Callable, Generator, Iterable, List, Optional
from langchain.document_loaders.web_base import WebBaseLoader
from langchain.schema.base import Document
from langchain.schema.document import Document
def _default_parsing_function(content: Any) -> str:

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
from langchain.embeddings.base import Embeddings
from langchain.math_utils import cosine_similarity
from langchain.schema.base import BaseDocumentTransformer, Document
from langchain.schema.document import BaseDocumentTransformer, Document
class _DocumentWithState(Document):

View File

@@ -14,13 +14,8 @@ from langchain.experimental.autonomous_agents.autogpt.prompt import AutoGPTPromp
from langchain.experimental.autonomous_agents.autogpt.prompt_generator import (
FINISH_NAME,
)
from langchain.schema.base import (
AIMessage,
BaseMessage,
Document,
HumanMessage,
SystemMessage,
)
from langchain.schema.base import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain.schema.document import Document
from langchain.tools.base import BaseTool
from langchain.tools.human.tool import HumanInputRun
from langchain.vectorstores.base import VectorStoreRetriever

View File

@@ -7,7 +7,8 @@ from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.schema.base import BaseMemory, Document
from langchain.schema.base import BaseMemory
from langchain.schema.document import Document
from langchain.utils import mock_now
logger = logging.getLogger(__name__)

View File

@@ -9,7 +9,7 @@ from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.base import Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.chroma import Chroma

View File

@@ -6,7 +6,7 @@ from pydantic import Field
from langchain.memory.chat_memory import BaseMemory
from langchain.memory.utils import get_prompt_input_key
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores.base import VectorStoreRetriever

View File

@@ -1,6 +1,11 @@
from typing import List
from typing import Any, List, Optional
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.utilities.arxiv import ArxivAPIWrapper
@@ -11,8 +16,20 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
It uses all ArxivAPIWrapper arguments without any change.
"""
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
    self,
    query: str,
    *,
    run_manager: Optional[CallbackManagerForRetrieverRun] = None,
    **kwargs: Any,
) -> List[Document]:
    """Load arXiv documents matching *query* via the API wrapper.

    ``run_manager`` and extra kwargs are accepted for the new retriever
    interface but are not used by this implementation.
    """
    return self.load(query=query)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
    self,
    query: str,
    *,
    run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
    **kwargs: Any,
) -> List[Document]:
    """Async retrieval is not implemented for the arXiv retriever."""
    raise NotImplementedError

View File

@@ -1,8 +1,14 @@
"""Retriever wrapper for AWS Kendra."""
import re
from typing import Any, Dict, List
from langchain.schema.base import BaseRetriever, Document
import re
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
class AwsKendraIndexRetriever(BaseRetriever):
@@ -84,12 +90,24 @@ Document Excerpt: {doc_excerpt}
return [self._get_top_n_results(response, i) for i in range(0, r_count)]
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
"""Run search on Kendra index and get top k documents
docs = get_relevant_documents('This is my query')
"""
return self._kendra_query(query)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError("AwsKendraIndexRetriever does not support async")

View File

@@ -1,14 +1,20 @@
"""Retriever wrapper for Azure Cognitive Search."""
from __future__ import annotations
import json
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
import aiohttp
import requests
from pydantic import BaseModel, Extra, root_validator
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.utils import get_from_dict_or_env
@@ -81,7 +87,13 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
return response_json["value"]
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
search_results = self._search(query)
return [
@@ -89,7 +101,13 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
for result in search_results
]
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
search_results = await self._asearch(query)
return [

View File

@@ -1,12 +1,17 @@
from __future__ import annotations
from typing import List, Optional
from typing import Any, List, Optional
import aiohttp
import requests
from pydantic import BaseModel
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
@@ -21,7 +26,13 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
url, json, headers = self._create_request(query)
response = requests.post(url, json=json, headers=headers)
results = response.json()["results"][0]["results"]
@@ -31,7 +42,13 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
docs.append(Document(page_content=content, metadata=d))
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
url, json, headers = self._create_request(query)
if not self.aiosession:

View File

@@ -1,12 +1,18 @@
"""Retriever that wraps a base retriever and filters the results."""
from typing import List
from typing import Any, List, Optional
from pydantic import BaseModel, Extra
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.retrievers.document_compressors.base import (
BaseDocumentCompressor,
)
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
class ContextualCompressionRetriever(BaseRetriever, BaseModel):
@@ -24,7 +30,13 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
extra = Extra.forbid
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query.
Args:
@@ -37,7 +49,13 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
compressed_docs = self.base_compressor.compress_documents(docs, query)
return list(compressed_docs)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query.
Args:

View File

@@ -1,9 +1,14 @@
from typing import List, Optional
from typing import Any, List, Optional
import aiohttp
import requests
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
class DataberryRetriever(BaseRetriever):
@@ -21,7 +26,13 @@ class DataberryRetriever(BaseRetriever):
self.api_key = api_key
self.top_k = top_k
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
response = requests.post(
self.datastore_url,
json={
@@ -46,7 +57,13 @@ class DataberryRetriever(BaseRetriever):
for r in data["results"]
]
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
async with aiohttp.ClientSession() as session:
async with session.request(
"POST",

View File

@@ -4,7 +4,7 @@ from typing import List, Sequence, Union
from pydantic import BaseModel
from langchain.schema.base import BaseDocumentTransformer, Document
from langchain.schema.document import BaseDocumentTransformer, Document
class BaseDocumentCompressor(BaseModel, ABC):

View File

@@ -10,7 +10,8 @@ from langchain.retrievers.document_compressors.base import BaseDocumentCompresso
from langchain.retrievers.document_compressors.chain_extract_prompt import (
prompt_template,
)
from langchain.schema.base import BaseOutputParser, Document
from langchain.schema.base import BaseOutputParser
from langchain.schema.document import Document
def default_get_input(query: str, doc: Document) -> Dict[str, Any]:

View File

@@ -8,7 +8,7 @@ from langchain.retrievers.document_compressors.base import BaseDocumentCompresso
from langchain.retrievers.document_compressors.chain_filter_prompt import (
prompt_template,
)
from langchain.schema.base import Document
from langchain.schema.document import Document
def _get_default_chain_prompt() -> PromptTemplate:

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Dict, Sequence
from pydantic import Extra, root_validator
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.utils import get_from_dict_or_env
if TYPE_CHECKING:

View File

@@ -13,7 +13,7 @@ from langchain.math_utils import cosine_similarity
from langchain.retrievers.document_compressors.base import (
BaseDocumentCompressor,
)
from langchain.schema.base import Document
from langchain.schema.document import Document
class EmbeddingsFilter(BaseDocumentCompressor):

View File

@@ -1,11 +1,16 @@
"""Wrapper around Elasticsearch vector database."""
from __future__ import annotations
import uuid
from typing import Any, Iterable, List
from typing import Any, Iterable, List, Optional
from langchain.docstore.document import Document
from langchain.schema.base import BaseRetriever
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
class ElasticSearchBM25Retriever(BaseRetriever):
@@ -111,7 +116,13 @@ class ElasticSearchBM25Retriever(BaseRetriever):
self.client.indices.refresh(index=self.index_name)
return ids
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
query_dict = {"query": {"match": {"content": query}}}
res = self.client.search(index=self.index_name, body=query_dict)
@@ -120,5 +131,11 @@ class ElasticSearchBM25Retriever(BaseRetriever):
docs.append(Document(page_content=r["_source"]["content"]))
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

View File

@@ -10,8 +10,13 @@ from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
@@ -39,7 +44,13 @@ class KNNRetriever(BaseRetriever, BaseModel):
index = create_index(texts, embeddings)
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
query_embeds = np.array(self.embeddings.embed_query(query))
# calc L2 norm
index_embeds = self.index / np.sqrt((self.index**2).sum(1, keepdims=True))
@@ -61,5 +72,11 @@ class KNNRetriever(BaseRetriever, BaseModel):
]
return top_k_results
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

View File

@@ -1,8 +1,13 @@
from typing import Any, Dict, List, cast
from typing import Any, Dict, List, Optional, cast
from pydantic import BaseModel, Field
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
class LlamaIndexRetriever(BaseRetriever, BaseModel):
@@ -11,7 +16,13 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
index: Any
query_kwargs: Dict = Field(default_factory=dict)
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query."""
try:
from llama_index.indices.base import BaseGPTIndex
@@ -33,7 +44,13 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError("LlamaIndexRetriever does not support async")
@@ -43,7 +60,13 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
graph: Any
query_configs: List[Dict] = Field(default_factory=list)
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query."""
try:
from llama_index.composability.graph import (
@@ -73,5 +96,11 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")

View File

@@ -1,6 +1,11 @@
from typing import Any, List, Optional
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
class MetalRetriever(BaseRetriever):
@@ -15,7 +20,13 @@ class MetalRetriever(BaseRetriever):
self.client: Metal = client
self.params = params or {}
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
results = self.client.search({"text": query}, **self.params)
final_results = []
for r in results["data"]:
@@ -23,5 +34,11 @@ class MetalRetriever(BaseRetriever):
final_results.append(Document(page_content=r["text"], metadata=metadata))
return final_results
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

View File

@@ -1,8 +1,14 @@
"""Milvus Retriever"""
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.milvus import Milvus
# TODO: Update to MilvusClient + Hybrid Search when available
@@ -36,8 +42,21 @@ class MilvusRetreiver(BaseRetriever):
"""
self.store.add_texts(texts, metadatas)
def get_relevant_documents(self, query: str) -> List[Document]:
return self.retriever.get_relevant_documents(query)
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
run_manager_ = run_manager or CallbackManagerForRetrieverRun.get_noop_manager()
return self.retriever.retrieve(query, callbacks=run_manager_.get_child())
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

View File

@@ -1,11 +1,17 @@
"""Taken from: https://docs.pinecone.io/docs/hybrid-search"""
import hashlib
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
def hash_text(text: str) -> str:
@@ -116,7 +122,13 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
)
return values
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
from pinecone_text.hybrid import hybrid_convex_scale
sparse_vec = self.sparse_encoder.encode_queries(query)
@@ -141,5 +153,11 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
# return search results as json
return final_result
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

View File

@@ -1,6 +1,11 @@
from typing import List
from typing import Any, List, Optional
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.utilities.pupmed import PubMedAPIWrapper
@@ -11,8 +16,20 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
It uses all PubMedAPIWrapper arguments without any change.
"""
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
return self.load_docs(query=query)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

View File

@@ -1,10 +1,15 @@
from typing import List, Optional
from typing import Any, List, Optional
import aiohttp
import requests
from pydantic import BaseModel
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
class RemoteLangChainRetriever(BaseRetriever, BaseModel):
@@ -15,7 +20,13 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
page_content_key: str = "page_content"
metadata_key: str = "metadata"
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
response = requests.post(
self.url, json={self.input_key: query}, headers=self.headers
)
@@ -27,7 +38,13 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
for r in result[self.response_key]
]
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
async with aiohttp.ClientSession() as session:
async with session.request(
"POST", self.url, headers=self.headers, json={self.input_key: query}

View File

@@ -1,10 +1,15 @@
"""Retriever that generates and executes structured queries over its own data source."""
from typing import Any, Dict, List, Optional, Type, cast
from pydantic import BaseModel, Field, root_validator
from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
@@ -12,7 +17,8 @@ from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores import Chroma, Pinecone, Qdrant, VectorStore, Weaviate
@@ -65,7 +71,13 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
)
return values
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query.
Args:
@@ -90,7 +102,13 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
@classmethod

View File

@@ -10,8 +10,13 @@ from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
@@ -39,7 +44,13 @@ class SVMRetriever(BaseRetriever, BaseModel):
index = create_index(texts, embeddings)
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
from sklearn import svm
query_embeds = np.array(self.embeddings.embed_query(query))
@@ -76,5 +87,11 @@ class SVMRetriever(BaseRetriever, BaseModel):
top_k_results.append(Document(page_content=self.texts[row - 1]))
return top_k_results
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

View File

@@ -2,13 +2,19 @@
Largely based on
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional
from pydantic import BaseModel
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
class TFIDFRetriever(BaseRetriever, BaseModel):
@@ -58,7 +64,13 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
texts=texts, tfidf_params=tfidf_params, metadatas=metadatas, **kwargs
)
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity
query_vec = self.vectorizer.transform(
@@ -70,5 +82,11 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
return return_docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

View File

@@ -1,11 +1,17 @@
"""Retriever that combines embedding similarity with recency in retrieving values."""
import datetime
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.base import VectorStore
@@ -80,7 +86,13 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
results[buffer_idx] = (doc, relevance)
return results
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
"""Return documents that are relevant to the query."""
current_time = datetime.datetime.now()
docs_and_scores = {
@@ -103,7 +115,13 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
result.append(buffered_doc)
return result
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
"""Return documents that are relevant to the query."""
raise NotImplementedError

View File

@@ -1,10 +1,16 @@
"""Wrapper for retrieving documents from Vespa."""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
if TYPE_CHECKING:
from vespa.application import Vespa
@@ -48,12 +54,24 @@ class VespaRetriever(BaseRetriever):
docs.append(Document(page_content=page_content, metadata=metadata))
return docs
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
body = self._query_body.copy()
body["query"] = query
return self._query(body)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
def get_relevant_documents_with_filter(

View File

@@ -6,8 +6,12 @@ from uuid import uuid4
from pydantic import Extra
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.schema.base import BaseRetriever
from langchain.schema.retriever import BaseRetriever
class WeaviateHybridSearchRetriever(BaseRetriever):
@@ -83,7 +87,12 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
return ids
def get_relevant_documents(
self, query: str, where_filter: Optional[Dict[str, object]] = None
self,
query: str,
*,
where_filter: Optional[Dict[str, object]] = None,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
"""Look up similar documents in Weaviate."""
query_obj = self._client.query.get(self._index_name, self._query_attrs)
@@ -102,6 +111,11 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
return docs
async def aget_relevant_documents(
self, query: str, where_filter: Optional[Dict[str, object]] = None
self,
query: str,
*,
where_filter: Optional[Dict[str, object]] = None,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

View File

@@ -1,6 +1,11 @@
from typing import List
from typing import Any, List, Optional
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.utilities.wikipedia import WikipediaAPIWrapper
@@ -11,8 +16,20 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
It uses all WikipediaAPIWrapper arguments without any change.
"""
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
return self.load(query=query)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

View File

@@ -1,8 +1,13 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
if TYPE_CHECKING:
from zep_python import MemorySearchResult
@@ -54,7 +59,12 @@ class ZepRetriever(BaseRetriever):
]
def get_relevant_documents(
self, query: str, metadata: Optional[Dict] = None
self,
query: str,
*,
metadata: Optional[Dict] = None,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
from zep_python import MemorySearchPayload
@@ -69,7 +79,12 @@ class ZepRetriever(BaseRetriever):
return self._search_result_to_doc(results)
async def aget_relevant_documents(
self, query: str, metadata: Optional[Dict] = None
self,
query: str,
*,
metadata: Optional[Dict] = None,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
from zep_python import MemorySearchPayload

View File

@@ -1,8 +1,14 @@
"""Zilliz Retriever"""
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.zilliz import Zilliz
# TODO: Update to ZillizClient + Hybrid Search when available
@@ -36,8 +42,21 @@ class ZillizRetreiver(BaseRetriever):
"""
self.store.add_texts(texts, metadatas)
def get_relevant_documents(self, query: str) -> List[Document]:
return self.retriever.get_relevant_documents(query)
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
_run_manager = run_manager or CallbackManagerForRetrieverRun.get_noop_manager()
return self.retriever.retrieve(query, callbacks=_run_manager.get_child())
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

View File

@@ -6,15 +6,12 @@ from langchain.schema.base import (
AgentFinish,
AIMessage,
BaseChatMessageHistory,
BaseDocumentTransformer,
BaseMemory,
BaseMessage,
BaseOutputParser,
BaseRetriever,
ChatGeneration,
ChatMessage,
ChatResult,
Document,
Generation,
HumanMessage,
LLMResult,
@@ -30,6 +27,8 @@ from langchain.schema.base import (
messages_from_dict,
messages_to_dict,
)
from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.schema.retriever import BaseRetriever
__all__ = [
"AIMessage",

View File

@@ -285,37 +285,6 @@ class BaseChatMessageHistory(ABC):
"""Remove all messages from the store"""
class Document(BaseModel):
"""Interface for interacting with a document."""
page_content: str
metadata: dict = Field(default_factory=dict)
class BaseRetriever(ABC):
@abstractmethod
def get_relevant_documents(self, query: str) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
@abstractmethod
async def aget_relevant_documents(self, query: str) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
# For backwards compatibility
@@ -405,19 +374,3 @@ class OutputParserException(ValueError):
self.observation = observation
self.llm_output = llm_output
self.send_to_llm = send_to_llm
class BaseDocumentTransformer(ABC):
"""Base interface for transforming documents."""
@abstractmethod
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Transform a list of documents."""
@abstractmethod
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a list of documents."""

View File

@@ -0,0 +1,30 @@
"""Schema for a document."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Sequence
from pydantic import BaseModel, Field
class Document(BaseModel):
    """Interface for interacting with a document.

    A document is a piece of text plus arbitrary metadata describing it.
    """

    # The raw text content of the document.
    page_content: str
    # Arbitrary metadata associated with the content (defaults to empty).
    metadata: dict = Field(default_factory=dict)
class BaseDocumentTransformer(ABC):
    """Base interface for transforming documents.

    Implementations take a sequence of documents and return a transformed
    sequence (e.g. split, filtered, or re-ordered documents).
    """

    @abstractmethod
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform a list of documents.

        Args:
            documents: the documents to transform.

        Returns:
            The transformed sequence of documents.
        """

    @abstractmethod
    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a list of documents.

        Args:
            documents: the documents to transform.

        Returns:
            The transformed sequence of documents.
        """

View File

@@ -0,0 +1,133 @@
"""Schema for a retriever."""
from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import (
Any,
List,
Optional,
)
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForRetrieverRun,
CallbackManager,
CallbackManagerForRetrieverRun,
Callbacks,
)
from langchain.schema.document import Document
class BaseRetriever(ABC):
    """Base interface for a retriever.

    Subclasses implement ``get_relevant_documents`` /
    ``aget_relevant_documents``; callers should go through ``retrieve`` /
    ``aretrieve``, which wrap the implementation with retriever callbacks
    (start / error / end events).
    """

    # Whether the subclass's overrides accept the new ``run_manager``
    # keyword.  Computed per subclass in __init_subclass__ so legacy
    # implementations with the old single-argument signature keep working.
    _new_arg_supported: bool = False
    _new_arg_supported_async: bool = False

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Inspect each override separately: a subclass may have migrated
        # only one of the sync/async implementations to the new signature.
        cls._new_arg_supported = (
            signature(cls.get_relevant_documents).parameters.get("run_manager")
            is not None
        )
        cls._new_arg_supported_async = (
            signature(cls.aget_relevant_documents).parameters.get("run_manager")
            is not None
        )

    @abstractmethod
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """

    @abstractmethod
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """

    def retrieve(
        self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
    ) -> List[Document]:
        """Retrieve documents relevant to a query, with callbacks.

        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks

        Returns:
            List of relevant documents

        Raises:
            Whatever the underlying implementation raises; the error is
            reported to the callbacks before being re-raised.
        """
        callback_manager = CallbackManager.configure(
            callbacks, None, verbose=kwargs.get("verbose", False)
        )
        run_manager = callback_manager.on_retriever_start(
            query,
            **kwargs,
        )
        try:
            # TODO: maybe also pass through run_manager if _run supports kwargs
            if self._new_arg_supported:
                result = self.get_relevant_documents(
                    query, run_manager=run_manager, **kwargs
                )
            else:
                # Legacy subclass with the old single-argument signature.
                result = self.get_relevant_documents(query)
        except (Exception, KeyboardInterrupt) as e:
            # Callback handlers accept Union[Exception, KeyboardInterrupt].
            run_manager.on_retriever_error(e)
            raise  # bare raise preserves the original traceback
        else:
            run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result

    async def aretrieve(
        self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
    ) -> List[Document]:
        """Asynchronously retrieve documents relevant to a query, with callbacks.

        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks

        Returns:
            List of relevant documents

        Raises:
            Whatever the underlying implementation raises; the error is
            reported to the callbacks before being re-raised.
        """
        callback_manager = AsyncCallbackManager.configure(
            callbacks, None, verbose=kwargs.get("verbose", False)
        )
        run_manager = await callback_manager.on_retriever_start(
            query,
            **kwargs,
        )
        try:
            # Dispatch on the *async* override's signature, not the sync one:
            # a subclass may have migrated only one of the two methods.
            if self._new_arg_supported_async:
                result = await self.aget_relevant_documents(
                    query, run_manager=run_manager, **kwargs
                )
            else:
                result = await self.aget_relevant_documents(query)
        except (Exception, KeyboardInterrupt) as e:
            await run_manager.on_retriever_error(e)
            raise  # bare raise preserves the original traceback
        else:
            await run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result

View File

@@ -22,8 +22,7 @@ from typing import (
Union,
)
from langchain.docstore.document import Document
from langchain.schema.base import BaseDocumentTransformer
from langchain.schema.document import BaseDocumentTransformer, Document
logger = logging.getLogger(__name__)

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.schema.base import Document
from langchain.schema.document import Document
logger = logging.getLogger(__name__)

View File

@@ -7,7 +7,7 @@ from typing import List
from pydantic import BaseModel, Extra
from langchain.schema.base import Document
from langchain.schema.document import Document
logger = logging.getLogger(__name__)

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.schema.base import Document
from langchain.schema.document import Document
logger = logging.getLogger(__name__)

View File

@@ -1,4 +1,5 @@
"""Interface for vector stores."""
from __future__ import annotations
import asyncio
@@ -20,9 +21,13 @@ from typing import (
from pydantic import BaseModel, Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever
from langchain.schema.retriever import BaseRetriever
VST = TypeVar("VST", bound="VectorStore")
@@ -387,7 +392,13 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
)
return values
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
elif self.search_type == "similarity_score_threshold":
@@ -405,7 +416,13 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs

View File

@@ -5,7 +5,7 @@ import numpy as np
from pydantic import Field
from langchain.embeddings.base import Embeddings
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance

View File

@@ -1,4 +1,5 @@
"""Wrapper around Redis vector database."""
from __future__ import annotations
import json
@@ -21,6 +22,10 @@ from typing import (
import numpy as np
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -567,7 +572,13 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, k=self.k)
elif self.search_type == "similarity_limit":
@@ -578,7 +589,13 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
    self,
    query: str,
    *,
    run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
    **kwargs: Any,
) -> List[Document]:
    """Async retrieval is not supported for Redis.

    Raises:
        NotImplementedError: always; use the synchronous
            ``get_relevant_documents`` instead.
    """
    raise NotImplementedError("RedisVectorStoreRetriever does not support async")
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:

View File

@@ -1,4 +1,5 @@
"""Wrapper around SingleStore DB."""
from __future__ import annotations
import json
@@ -15,6 +16,10 @@ from typing import (
from sqlalchemy.pool import QueuePool
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
@@ -359,14 +364,26 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
k: int = 4
allowed_search_types: ClassVar[Collection[str]] = ("similarity",)
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
    self,
    query: str,
    *,
    run_manager: Optional[CallbackManagerForRetrieverRun] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return documents relevant to *query* via similarity search.

    Only the "similarity" search type is supported; any other value of
    ``self.search_type`` is rejected.
    """
    if self.search_type != "similarity":
        raise ValueError(f"search_type of {self.search_type} not allowed.")
    return self.vectorstore.similarity_search(query, k=self.k)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(
    self,
    query: str,
    *,
    run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
    **kwargs: Any,
) -> List[Document]:
    """Async retrieval is not supported for SingleStoreDB.

    Raises:
        NotImplementedError: always; use the synchronous
            ``get_relevant_documents`` instead.
    """
    # Fixed: message previously named "SingleStoreDBVectorStoreRetriever",
    # which does not match the enclosing class SingleStoreDBRetriever.
    raise NotImplementedError(
        "SingleStoreDBRetriever does not support async"
    )

View File

@@ -4,7 +4,7 @@ import itertools
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
from langchain.embeddings.base import Embeddings
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores import VectorStore
if TYPE_CHECKING:

View File

@@ -11,7 +11,7 @@ import requests
from pydantic import Field
from langchain.embeddings.base import Embeddings
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores.base import VectorStore, VectorStoreRetriever

View File

@@ -1,7 +1,7 @@
from typing import List
from langchain.document_loaders.arxiv import ArxivLoader
from langchain.schema.base import Document
from langchain.schema.document import Document
def assert_docs(docs: List[Document]) -> None:

View File

@@ -2,7 +2,7 @@ import pandas as pd
import pytest
from langchain.document_loaders import DataFrameLoader
from langchain.schema.base import Document
from langchain.schema.document import Document
@pytest.fixture

View File

@@ -5,7 +5,7 @@ from langchain.retrievers.document_compressors import (
DocumentCompressorPipeline,
EmbeddingsFilter,
)
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.text_splitter import CharacterTextSplitter

View File

@@ -1,7 +1,7 @@
"""Integration test for LLMChainExtractor."""
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.schema.base import Document
from langchain.schema.document import Document
def test_llm_construction_with_kwargs() -> None:

View File

@@ -1,7 +1,7 @@
"""Integration test for llm-based relevant doc filtering."""
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.schema.base import Document
from langchain.schema.document import Document
def test_llm_chain_filter() -> None:

View File

@@ -4,7 +4,7 @@ import numpy as np
from langchain.document_transformers import _DocumentWithState
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema.base import Document
from langchain.schema.document import Document
def test_embeddings_filter() -> None:

View File

@@ -4,7 +4,7 @@ from typing import List
import pytest
from langchain.retrievers import ArxivRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document
@pytest.fixture

View File

@@ -2,7 +2,7 @@
import pytest
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document
def test_azure_cognitive_search_get_relevant_documents() -> None:

View File

@@ -4,7 +4,7 @@ from typing import List
import pytest
from langchain.retrievers import PubMedRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document
@pytest.fixture

View File

@@ -4,7 +4,7 @@ from typing import List
import pytest
from langchain.retrievers import WikipediaRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document
@pytest.fixture

View File

@@ -4,7 +4,7 @@ from langchain.document_transformers import (
_DocumentWithState,
)
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.base import Document
from langchain.schema.document import Document
def test_embeddings_redundant_filter() -> None:

View File

@@ -4,7 +4,7 @@ from typing import Any, List
import pytest
from langchain.agents.load_tools import load_tools
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.tools.base import BaseTool
from langchain.utilities import ArxivAPIWrapper

View File

@@ -4,7 +4,7 @@ from typing import Any, List
import pytest
from langchain.agents.load_tools import load_tools
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.tools.base import BaseTool
from langchain.utilities import PubMedAPIWrapper

View File

@@ -3,7 +3,7 @@ from typing import List
import pytest
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.utilities import WikipediaAPIWrapper

View File

@@ -6,7 +6,7 @@ from vcr.request import Request
from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.text_splitter import CharacterTextSplitter
# Those environment variables turn on Deep Lake pytest mode.

View File

@@ -4,7 +4,7 @@ from typing import List
import numpy as np
import pytest
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores.docarray import DocArrayHnswSearch
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings

View File

@@ -4,7 +4,7 @@ from typing import List
import numpy as np
import pytest
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores.docarray import DocArrayInMemorySearch
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings

View File

@@ -1,5 +1,5 @@
from langchain.docstore.arbitrary_fn import DocstoreFn
from langchain.schema.base import Document
from langchain.schema.document import Document
def test_document_found() -> None:

View File

@@ -7,7 +7,7 @@ import pytest
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.document_loaders.parsers.generic import MimeTypeBasedParser
from langchain.schema.base import Document
from langchain.schema.document import Document
class TestMimeBasedParser:

View File

@@ -3,7 +3,7 @@ from typing import Iterator
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document
def test_base_blob_parser() -> None:

View File

@@ -9,7 +9,7 @@ import pytest
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob, FileSystemBlobLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.schema.base import Document
from langchain.schema.document import Document
@pytest.fixture

View File

@@ -1,7 +1,7 @@
import pytest
from langchain.retrievers.tfidf import TFIDFRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document
@pytest.mark.requires("sklearn")

View File

@@ -10,7 +10,7 @@ from langchain.retrievers.time_weighted_retriever import (
TimeWeightedVectorStoreRetriever,
_get_hours_passed,
)
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores.base import VectorStore

View File

@@ -7,7 +7,7 @@ import pytest
from pytest_mock import MockerFixture
from langchain.retrievers import ZepRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document
if TYPE_CHECKING:
from zep_python import MemorySearchResult, ZepClient