Mirror of https://github.com/hwchase17/langchain.git
Synced 2026-02-07 01:30:24 +00:00

Compare commits: vwp/schema ... vwp/retrie (1 commit)
Commit 774c405707
@@ -29,6 +29,7 @@ class BaseMetadataCallbackHandler:
        ignore_llm_ (bool): Whether to ignore llm callbacks.
        ignore_chain_ (bool): Whether to ignore chain callbacks.
        ignore_agent_ (bool): Whether to ignore agent callbacks.
        ignore_retriever_ (bool): Whether to ignore retriever callbacks.
        always_verbose_ (bool): Whether to always be verbose.
        chain_starts (int): The number of times the chain start method has been called.
        chain_ends (int): The number of times the chain end method has been called.
@@ -51,6 +52,7 @@ class BaseMetadataCallbackHandler:
        self.ignore_llm_ = False
        self.ignore_chain_ = False
        self.ignore_agent_ = False
        self.ignore_retriever_ = False
        self.always_verbose_ = False

        self.chain_starts = 0
@@ -85,6 +87,11 @@ class BaseMetadataCallbackHandler:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    @property
    def ignore_retriever(self) -> bool:
        """Whether to ignore retriever callbacks."""
        return self.ignore_retriever_

    def get_custom_callback_meta(self) -> Dict[str, Any]:
        return {
            "step": self.step,
@@ -1,15 +1,45 @@
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations

from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID

from langchain.schema.base import (
    AgentAction,
    AgentFinish,
    BaseMessage,
    LLMResult,
)
from langchain.schema.base import AgentAction, AgentFinish, BaseMessage, LLMResult
from langchain.schema.document import Document


class RetrieverManagerMixin:
    """Mixin for Retriever callbacks."""

    def on_retriever_start(
        self,
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever starts running."""

    def on_retriever_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever errors."""

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever ends running."""


class LLMManagerMixin:
@@ -183,6 +213,7 @@ class BaseCallbackHandler(
    LLMManagerMixin,
    ChainManagerMixin,
    ToolManagerMixin,
    RetrieverManagerMixin,
    CallbackManagerMixin,
    RunManagerMixin,
):
@@ -361,6 +392,36 @@ class AsyncCallbackHandler(BaseCallbackHandler):
    ) -> None:
        """Run on agent end."""

    async def on_retriever_start(
        self,
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Run on retriever start."""

    async def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Run on retriever end."""

    async def on_retriever_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Run on retriever error."""


class BaseCallbackManager(CallbackManagerMixin):
    """Base callback manager that can be used to handle callbacks from LangChain."""
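Taken together, these hooks mean any handler subclassing BaseCallbackHandler can now observe retriever activity. A minimal sketch, assuming only the signatures introduced above (the handler name and print-based logging are illustrative, not part of this change):

from typing import Any, Optional, Sequence, Union
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.document import Document


class PrintRetrieverHandler(BaseCallbackHandler):
    """Hypothetical handler that prints retriever lifecycle events."""

    def on_retriever_start(
        self,
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        # Fired once per retriever run, before any documents are fetched.
        print(f"[{run_id}] retriever start: {query!r}")

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        # Fired on success, with the retrieved documents.
        print(f"[{run_id}] retriever end: {len(documents)} documents")

    def on_retriever_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        # Fired instead of on_retriever_end when retrieval raises.
        print(f"[{run_id}] retriever error: {error!r}")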
@@ -14,6 +14,7 @@ from typing import (
    Generator,
    List,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
@@ -27,6 +28,7 @@ from langchain.callbacks.base import (
    BaseCallbackManager,
    ChainManagerMixin,
    LLMManagerMixin,
    RetrieverManagerMixin,
    RunManagerMixin,
    ToolManagerMixin,
)
@@ -43,6 +45,7 @@ from langchain.schema.base import (
    LLMResult,
    get_buffer_string,
)
from langchain.schema.document import Document

logger = logging.getLogger(__name__)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
@@ -624,6 +627,91 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
        )


class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin):
    """Callback manager for retriever run."""

    def get_child(self) -> CallbackManager:
        """Get a child callback manager."""
        manager = CallbackManager([], parent_run_id=self.run_id)
        manager.set_handlers(self.inheritable_handlers)
        return manager

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        **kwargs: Any,
    ) -> None:
        """Run when retriever ends running."""
        _handle_event(
            self.handlers,
            "on_retriever_end",
            "ignore_retriever",
            documents,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    def on_retriever_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when retriever errors."""
        _handle_event(
            self.handlers,
            "on_retriever_error",
            "ignore_retriever",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )


class AsyncCallbackManagerForRetrieverRun(
    AsyncRunManager,
    RetrieverManagerMixin,
):
    """Async callback manager for retriever run."""

    def get_child(self) -> AsyncCallbackManager:
        """Get a child callback manager."""
        manager = AsyncCallbackManager([], parent_run_id=self.run_id)
        manager.set_handlers(self.inheritable_handlers)
        return manager

    async def on_retriever_end(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> None:
        """Run when retriever ends running."""
        await _ahandle_event(
            self.handlers,
            "on_retriever_end",
            "ignore_retriever",
            documents,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    async def on_retriever_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when retriever errors."""
        await _ahandle_event(
            self.handlers,
            "on_retriever_error",
            "ignore_retriever",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )


class CallbackManager(BaseCallbackManager):
    """Callback manager that can be used to handle callbacks from langchain."""
@@ -733,6 +821,29 @@ class CallbackManager(BaseCallbackManager):
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    def on_retriever_start(
        self,
        query: str,
        run_id: Optional[UUID] = None,
        parent_run_id: Optional[UUID] = None,
    ) -> CallbackManagerForRetrieverRun:
        """Run when retriever starts running."""
        if run_id is None:
            run_id = uuid4()

        _handle_event(
            self.handlers,
            "on_retriever_start",
            "ignore_retriever",
            query,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
        )

        return CallbackManagerForRetrieverRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    @classmethod
    def configure(
        cls,
@@ -856,6 +967,29 @@ class AsyncCallbackManager(BaseCallbackManager):
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    async def on_retriever_start(
        self,
        query: str,
        run_id: Optional[UUID] = None,
        parent_run_id: Optional[UUID] = None,
    ) -> AsyncCallbackManagerForRetrieverRun:
        """Run when retriever starts running."""
        if run_id is None:
            run_id = uuid4()

        await _ahandle_event(
            self.handlers,
            "on_retriever_start",
            "ignore_retriever",
            query,
            run_id=run_id,
            parent_run_id=self.parent_run_id,
        )

        return AsyncCallbackManagerForRetrieverRun(
            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
        )

    @classmethod
    def configure(
        cls,
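The intended call sequence mirrors the existing LLM, chain, and tool managers: on_retriever_start returns a run manager bound to a fresh run_id, and the caller reports the outcome on that manager. A sketch under the assumption that a handler like the one sketched earlier is registered (the search function is a stand-in, not part of this diff):

from langchain.callbacks.manager import CallbackManager


def search(query: str) -> list:
    # Stand-in for real retrieval work; any document-producing call fits here.
    return []


manager = CallbackManager(handlers=[PrintRetrieverHandler()])
run_manager = manager.on_retriever_start("what is a retriever callback?")
try:
    docs = search("what is a retriever callback?")
except Exception as err:
    run_manager.on_retriever_error(err)
    raise
else:
    run_manager.on_retriever_end(docs)

Nested components launched during retrieval would take run_manager.get_child() as their callbacks, so their runs are parented to this retriever run.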
@@ -3,12 +3,13 @@ from __future__ import annotations

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.schema.base import LLMResult
from langchain.schema.document import Document


class TracerException(Exception):
@@ -262,6 +263,65 @@ class BaseTracer(BaseCallbackHandler, ABC):
        self._end_trace(tool_run)
        self._on_tool_error(tool_run)

    def on_retriever_start(
        self,
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Run when Retriever starts running."""
        parent_run_id_ = str(parent_run_id) if parent_run_id else None
        execution_order = self._get_execution_order(parent_run_id_)
        retrieval_run = Run(
            id=run_id,
            name="Retriever",
            parent_run_id=parent_run_id,
            inputs={"query": query},
            extra=kwargs,
            start_time=datetime.utcnow(),
            execution_order=execution_order,
            child_execution_order=execution_order,
            child_runs=[],
            run_type=RunTypeEnum.retriever,
        )
        self._start_trace(retrieval_run)
        self._on_retriever_start(retrieval_run)

    def on_retriever_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        """Run when Retriever errors."""
        if not run_id:
            raise TracerException("No run_id provided for on_retriever_error callback.")
        retrieval_run = self.run_map.get(str(run_id))
        if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
            raise TracerException("No retriever Run found to be traced")

        retrieval_run.error = repr(error)
        retrieval_run.end_time = datetime.utcnow()
        self._end_trace(retrieval_run)
        self._on_retriever_error(retrieval_run)

    def on_retriever_end(
        self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
    ) -> None:
        """Run when Retriever ends running."""
        if not run_id:
            raise TracerException("No run_id provided for on_retriever_end callback.")
        retrieval_run = self.run_map.get(str(run_id))
        if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
            raise TracerException("No retriever Run found to be traced")
        retrieval_run.outputs = {"documents": documents}
        retrieval_run.end_time = datetime.utcnow()
        self._end_trace(retrieval_run)
        self._on_retriever_end(retrieval_run)

    def __deepcopy__(self, memo: dict) -> BaseTracer:
        """Deepcopy the tracer."""
        return self
@@ -299,3 +359,12 @@

    def _on_chat_model_start(self, run: Run) -> None:
        """Process the Chat Model Run upon start."""

    def _on_retriever_start(self, run: Run) -> None:
        """Process the Retriever Run upon start."""

    def _on_retriever_end(self, run: Run) -> None:
        """Process the Retriever Run."""

    def _on_retriever_error(self, run: Run) -> None:
        """Process the Retriever Run upon error."""
@@ -122,3 +122,15 @@ class LangChainTracer(BaseTracer):
    def _on_tool_error(self, run: Run) -> None:
        """Process the Tool Run upon error."""
        self.executor.submit(self._update_run_single, run.copy(deep=True))

    def _on_retriever_start(self, run: Run) -> None:
        """Process the Retriever Run upon start."""
        self.executor.submit(self._persist_run_single, run.copy(deep=True))

    def _on_retriever_end(self, run: Run) -> None:
        """Process the Retriever Run."""
        self.executor.submit(self._update_run_single, run.copy(deep=True))

    def _on_retriever_error(self, run: Run) -> None:
        """Process the Retriever Run upon error."""
        self.executor.submit(self._update_run_single, run.copy(deep=True))
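Because BaseTracer routes the three callbacks into _on_retriever_* template hooks, a subclass only has to override those to consume retriever runs; LangChainTracer above is one example. Another minimal sketch (the class name and in-memory collection are illustrative, and it assumes BaseTracer's abstract _persist_run hook from the same module):

from typing import List

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run


class InMemoryRetrieverTracer(BaseTracer):
    """Hypothetical tracer that accumulates finished retriever runs."""

    def __init__(self) -> None:
        super().__init__()
        self.retriever_runs: List[Run] = []

    def _persist_run(self, run: Run) -> None:
        # No external persistence; runs are kept on the instance instead.
        pass

    def _on_retriever_end(self, run: Run) -> None:
        # run.inputs holds {"query": ...} and run.outputs {"documents": ...}.
        self.retriever_runs.append(run)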
@@ -94,6 +94,7 @@ class RunTypeEnum(str, Enum):
    tool = "tool"
    chain = "chain"
    llm = "llm"
    retriever = "retriever"


class RunBase(BaseModel):
@@ -21,7 +21,9 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema.base import BaseMessage, BaseRetriever, Document
from langchain.schema.base import BaseMessage
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.base import VectorStore

# Depending on the memory type and configuration, the chat history format may differ.
@@ -87,7 +89,13 @@ class BaseConversationalRetrievalChain(Chain):
        return _output_keys

    @abstractmethod
    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[Document]:
        """Get docs."""

    def _call(
@@ -107,7 +115,7 @@ class BaseConversationalRetrievalChain(Chain):
            )
        else:
            new_question = question
        docs = self._get_docs(new_question, inputs)
        docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        new_inputs = inputs.copy()
        new_inputs["question"] = new_question
        new_inputs["chat_history"] = chat_history_str
@@ -122,7 +130,13 @@ class BaseConversationalRetrievalChain(Chain):
        return output

    @abstractmethod
    async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> List[Document]:
        """Get docs."""

    async def _acall(
@@ -141,7 +155,7 @@ class BaseConversationalRetrievalChain(Chain):
            )
        else:
            new_question = question
        docs = await self._aget_docs(new_question, inputs)
        docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
        new_inputs = inputs.copy()
        new_inputs["question"] = new_question
        new_inputs["chat_history"] = chat_history_str
@@ -187,12 +201,28 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):

        return docs[:num_docs]

    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
        docs = self.retriever.get_relevant_documents(question)
    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[Document]:
        run_manager_ = run_manager or CallbackManagerForChainRun.get_noop_manager()
        docs = self.retriever.retrieve(question, callbacks=run_manager_.get_child())
        return self._reduce_tokens_below_limit(docs)

    async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
        docs = await self.retriever.aget_relevant_documents(question)
    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> List[Document]:
        run_manager_ = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        docs = await self.retriever.aget_relevant_documents(
            question, callbacks=run_manager_.get_child()
        )
        return self._reduce_tokens_below_limit(docs)

    @classmethod
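The pattern repeated throughout these chain changes is worth calling out: run_manager is optional, a no-op manager is substituted when it is absent, and get_child() hands nested components a manager parented to the current run. In isolation (the function name is illustrative):

from typing import List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun


def fetch_docs(
    question: str,
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[str]:
    # Fall back to a manager that silently drops events, so callers without
    # callbacks configured pay no cost and the body needs no branching.
    run_manager_ = run_manager or CallbackManagerForChainRun.get_noop_manager()
    child_callbacks = run_manager_.get_child()  # parented to this chain run
    # ... pass child_callbacks to the retriever or sub-chain here ...
    return []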
@@ -253,14 +283,26 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
        )
        return values

    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[Document]:
        vectordbkwargs = inputs.get("vectordbkwargs", {})
        full_kwargs = {**self.search_kwargs, **vectordbkwargs}
        return self.vectorstore.similarity_search(
            question, k=self.top_k_docs_for_context, **full_kwargs
        )

    async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> List[Document]:
        raise NotImplementedError("ChatVectorDBChain does not support async")

    @classmethod
@@ -8,9 +8,7 @@ import numpy as np
from pydantic import Field

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
    CallbackManagerForChainRun,
)
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
    PROMPT,
@@ -20,7 +18,8 @@ from langchain.chains.flare.prompts import (
from langchain.chains.llm import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import BasePromptTemplate
from langchain.schema.base import BaseRetriever, Generation
from langchain.schema.base import Generation
from langchain.schema.retriever import BaseRetriever


class _ResponseChain(LLMChain):
@@ -124,7 +123,7 @@ class FlareChain(Chain):
        callbacks = _run_manager.get_child()
        docs = []
        for question in questions:
            docs.extend(self.retriever.get_relevant_documents(question))
            docs.extend(self.retriever.retrieve(question, callbacks=callbacks))
        context = "\n\n".join(d.page_content for d in docs)
        result = self.response_chain.predict(
            user_input=user_input,
@@ -115,7 +115,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
        return values

    @abstractmethod
    def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
    def _get_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[Document]:
        """Get docs to run questioning over."""

    def _call(
@@ -124,7 +129,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        docs = self._get_docs(inputs)
        docs = self._get_docs(inputs, run_manager=_run_manager)
        answer = self.combine_documents_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )
@@ -141,7 +146,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
        return result

    @abstractmethod
    async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
    async def _aget_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> List[Document]:
        """Get docs to run questioning over."""

    async def _acall(
@@ -150,7 +160,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        docs = await self._aget_docs(inputs)
        docs = await self._aget_docs(inputs, run_manager=_run_manager)
        answer = await self.combine_documents_chain.arun(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )
@@ -180,10 +190,20 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
        """
        return [self.input_docs_key, self.question_key]

    def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
    def _get_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[Document]:
        return inputs.pop(self.input_docs_key)

    async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
    async def _aget_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> List[Document]:
        return inputs.pop(self.input_docs_key)

    @property
@@ -1,13 +1,17 @@
"""Question-answering with sources over an index."""

from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from pydantic import Field

from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
from langchain.schema.base import BaseRetriever
from langchain.schema.retriever import BaseRetriever


class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
@@ -40,12 +44,26 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):

        return docs[:num_docs]

    def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
    def _get_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None
    ) -> List[Document]:
        run_manager_ = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.question_key]
        docs = self.retriever.get_relevant_documents(question)
        docs = self.retriever.retrieve(question, callbacks=run_manager_.get_child())
        return self._reduce_tokens_below_limit(docs)

    async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
    async def _aget_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None
    ) -> List[Document]:
        run_manager_ = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.question_key]
        docs = await self.retriever.aget_relevant_documents(question)
        docs = await self.retriever.aretrieve(
            question, callbacks=run_manager_.get_child()
        )
        return self._reduce_tokens_below_limit(docs)
@@ -1,10 +1,14 @@
"""Question-answering with sources over a vector database."""

import warnings
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from pydantic import Field, root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
@@ -45,14 +49,24 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):

        return docs[:num_docs]

    def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
    def _get_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None
    ) -> List[Document]:
        question = inputs[self.question_key]
        docs = self.vectorstore.similarity_search(
            question, k=self.k, **self.search_kwargs
        )
        return self._reduce_tokens_below_limit(docs)

    async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
    async def _aget_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None
    ) -> List[Document]:
        raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")

    @root_validator()
@@ -19,7 +19,8 @@ from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.prompts import PromptTemplate
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.base import VectorStore
@@ -94,7 +95,9 @@ class BaseRetrievalQA(Chain):
|
||||
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
def _get_docs(
|
||||
self, question: str, *, run_manager: CallbackManagerForChainRun
|
||||
) -> List[Document]:
|
||||
"""Get documents to do question answering over."""
|
||||
|
||||
def _call(
|
||||
@@ -116,7 +119,7 @@ class BaseRetrievalQA(Chain):
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]

        docs = self._get_docs(question)
        docs = self._get_docs(question, run_manager=_run_manager)
        answer = self.combine_documents_chain.run(
            input_documents=docs, question=question, callbacks=_run_manager.get_child()
        )
@@ -127,7 +130,9 @@ class BaseRetrievalQA(Chain):
        return {self.output_key: answer}

    @abstractmethod
    async def _aget_docs(self, question: str) -> List[Document]:
    async def _aget_docs(
        self, question: str, *, run_manager: AsyncCallbackManagerForChainRun
    ) -> List[Document]:
        """Get documents to do question answering over."""

    async def _acall(
@@ -149,7 +154,7 @@ class BaseRetrievalQA(Chain):
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]

        docs = await self._aget_docs(question)
        docs = await self._aget_docs(question, run_manager=_run_manager)
        answer = await self.combine_documents_chain.arun(
            input_documents=docs, question=question, callbacks=_run_manager.get_child()
        )
@@ -177,11 +182,22 @@ class RetrievalQA(BaseRetrievalQA):

    retriever: BaseRetriever = Field(exclude=True)

    def _get_docs(self, question: str) -> List[Document]:
        return self.retriever.get_relevant_documents(question)
    def _get_docs(
        self, question: str, *, run_manager: Optional[CallbackManagerForChainRun] = None
    ) -> List[Document]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.retriever.retrieve(question, run_manager=_run_manager.get_child())

    async def _aget_docs(self, question: str) -> List[Document]:
        return await self.retriever.aget_relevant_documents(question)
    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> List[Document]:
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        return await self.retriever.aretrieve(
            question, callbacks=_run_manager.get_child()
        )

    @property
    def _chain_type(self) -> str:
@@ -218,7 +234,9 @@ class VectorDBQA(BaseRetrievalQA):
            raise ValueError(f"search_type of {search_type} not allowed.")
        return values

    def _get_docs(self, question: str) -> List[Document]:
    def _get_docs(
        self, question: str, *, run_manager: CallbackManagerForChainRun
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(
                question, k=self.k, **self.search_kwargs
@@ -231,7 +249,9 @@ class VectorDBQA(BaseRetrievalQA):
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    async def _aget_docs(self, question: str) -> List[Document]:
    async def _aget_docs(
        self, question: str, *, run_manager: AsyncCallbackManagerForChainRun
    ) -> List[Document]:
        raise NotImplementedError("VectorDBQA does not support async")

    @property
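With these signatures in place, callbacks handed to a chain reach the retriever with no extra wiring. A hedged usage sketch, where llm and my_retriever are placeholders for a configured model and retriever rather than anything defined in this diff:

from langchain.chains import RetrievalQA

qa = RetrievalQA.from_chain_type(llm=llm, retriever=my_retriever)
qa.run(
    "What does the document say about callbacks?",
    callbacks=[PrintRetrieverHandler()],  # also receives the retriever events
)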
@@ -15,7 +15,7 @@ from langchain.chains.router.multi_retrieval_prompt import (
)
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.schema.base import BaseRetriever
from langchain.schema.retriever import BaseRetriever


class MultiRetrievalQAChain(MultiRouteChain):
@@ -1,7 +1,7 @@
from typing import Callable, Union

from langchain.docstore.base import Docstore
from langchain.schema.base import Document
from langchain.schema.document import Document


class DocstoreFn(Docstore):
@@ -1,3 +1,3 @@
from langchain.schema.base import Document
from langchain.schema.document import Document

__all__ = ["Document"]
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from typing import Iterator, List, Optional

from langchain.document_loaders.blob_loaders import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@@ -6,7 +6,7 @@ from typing import Iterator, List, Literal, Optional, Sequence, Union
from langchain.document_loaders.base import BaseBlobParser, BaseLoader
from langchain.document_loaders.blob_loaders import BlobLoader, FileSystemBlobLoader
from langchain.document_loaders.parsers.registry import get_parser
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.text_splitter import TextSplitter

_PathLike = Union[str, Path]
@@ -4,7 +4,7 @@ from datetime import datetime
from typing import Iterator, List, Optional

from langchain.document_loaders.base import BaseLoader
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.utils import get_from_env

LINK_NOTE_TEMPLATE = "joplin://x-callback-url/openNote?id={id}"
@@ -2,7 +2,7 @@ from typing import Iterator

from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document


class OpenAIWhisperParser(BaseBlobParser):
@@ -6,7 +6,7 @@ from typing import Iterator, Mapping, Optional

from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders.schema import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document


class MimeTypeBasedParser(BaseBlobParser):
@@ -3,7 +3,7 @@ from typing import Any, Iterator, Mapping, Optional

from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document


class PyPDFParser(BaseBlobParser):
@@ -3,7 +3,7 @@ from typing import Iterator

from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document


class TextParser(BaseBlobParser):
@@ -4,7 +4,7 @@ import re
from typing import Any, Callable, Generator, Iterable, List, Optional

from langchain.document_loaders.web_base import WebBaseLoader
from langchain.schema.base import Document
from langchain.schema.document import Document


def _default_parsing_function(content: Any) -> str:
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field

from langchain.embeddings.base import Embeddings
from langchain.math_utils import cosine_similarity
from langchain.schema.base import BaseDocumentTransformer, Document
from langchain.schema.document import BaseDocumentTransformer, Document


class _DocumentWithState(Document):
@@ -14,13 +14,8 @@ from langchain.experimental.autonomous_agents.autogpt.prompt import AutoGPTPromp
from langchain.experimental.autonomous_agents.autogpt.prompt_generator import (
    FINISH_NAME,
)
from langchain.schema.base import (
    AIMessage,
    BaseMessage,
    Document,
    HumanMessage,
    SystemMessage,
)
from langchain.schema.base import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain.schema.document import Document
from langchain.tools.base import BaseTool
from langchain.tools.human.tool import HumanInputRun
from langchain.vectorstores.base import VectorStoreRetriever
@@ -7,7 +7,8 @@ from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.schema.base import BaseMemory, Document
from langchain.schema.base import BaseMemory
from langchain.schema.document import Document
from langchain.utils import mock_now

logger = logging.getLogger(__name__)
@@ -9,7 +9,7 @@ from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.base import Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.chroma import Chroma
@@ -6,7 +6,7 @@ from pydantic import Field

from langchain.memory.chat_memory import BaseMemory
from langchain.memory.utils import get_prompt_input_key
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores.base import VectorStoreRetriever
@@ -1,6 +1,11 @@
from typing import List
from typing import Any, List, Optional

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.utilities.arxiv import ArxivAPIWrapper


@@ -11,8 +16,20 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
    It uses all ArxivAPIWrapper arguments without any change.
    """

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        return self.load(query=query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError
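All third-party retrievers in this commit move to the same keyword-only signature, so a custom retriever written against it would look like the following sketch (StaticRetriever is illustrative; it assumes the plain-ABC BaseRetriever from langchain.schema.retriever used above):

from typing import Any, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class StaticRetriever(BaseRetriever):
    """Hypothetical retriever that returns a fixed set of documents."""

    def __init__(self, docs: List[Document]) -> None:
        self.docs = docs

    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        # run_manager is accepted for parity with the new interface; a static
        # lookup has no nested work to attach to it.
        return self.docs

    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        return self.docs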
@@ -1,8 +1,14 @@
"""Retriever wrapper for AWS Kendra."""
import re
from typing import Any, Dict, List

from langchain.schema.base import BaseRetriever, Document
import re
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class AwsKendraIndexRetriever(BaseRetriever):
@@ -84,12 +90,24 @@ Document Excerpt: {doc_excerpt}

        return [self._get_top_n_results(response, i) for i in range(0, r_count)]

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run search on Kendra index and get top k documents

        docs = get_relevant_documents('This is my query')
        """
        return self._kendra_query(query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError("AwsKendraIndexRetriever does not support async")
@@ -1,14 +1,20 @@
"""Retriever wrapper for Azure Cognitive Search."""

from __future__ import annotations

import json
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import aiohttp
import requests
from pydantic import BaseModel, Extra, root_validator

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.utils import get_from_dict_or_env


@@ -81,7 +87,13 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):

        return response_json["value"]

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        search_results = self._search(query)

        return [
@@ -89,7 +101,13 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
            for result in search_results
        ]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        search_results = await self._asearch(query)

        return [
@@ -1,12 +1,17 @@
from __future__ import annotations

from typing import List, Optional
from typing import Any, List, Optional

import aiohttp
import requests
from pydantic import BaseModel

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
@@ -21,7 +26,13 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):

        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        url, json, headers = self._create_request(query)
        response = requests.post(url, json=json, headers=headers)
        results = response.json()["results"][0]["results"]
@@ -31,7 +42,13 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
            docs.append(Document(page_content=content, metadata=d))
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        url, json, headers = self._create_request(query)

        if not self.aiosession:
@@ -1,12 +1,18 @@
"""Retriever that wraps a base retriever and filters the results."""
from typing import List

from typing import Any, List, Optional

from pydantic import BaseModel, Extra

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.retrievers.document_compressors.base import (
    BaseDocumentCompressor,
)
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class ContextualCompressionRetriever(BaseRetriever, BaseModel):
@@ -24,7 +30,13 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
        extra = Extra.forbid
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
@@ -37,7 +49,13 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
        compressed_docs = self.base_compressor.compress_documents(docs, query)
        return list(compressed_docs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
@@ -1,9 +1,14 @@
from typing import List, Optional
from typing import Any, List, Optional

import aiohttp
import requests

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class DataberryRetriever(BaseRetriever):
@@ -21,7 +26,13 @@ class DataberryRetriever(BaseRetriever):
        self.api_key = api_key
        self.top_k = top_k

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        response = requests.post(
            self.datastore_url,
            json={
@@ -46,7 +57,13 @@ class DataberryRetriever(BaseRetriever):
            for r in data["results"]
        ]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        async with aiohttp.ClientSession() as session:
            async with session.request(
                "POST",
@@ -4,7 +4,7 @@ from typing import List, Sequence, Union

from pydantic import BaseModel

from langchain.schema.base import BaseDocumentTransformer, Document
from langchain.schema.document import BaseDocumentTransformer, Document


class BaseDocumentCompressor(BaseModel, ABC):
@@ -10,7 +10,8 @@ from langchain.retrievers.document_compressors.base import BaseDocumentCompresso
from langchain.retrievers.document_compressors.chain_extract_prompt import (
    prompt_template,
)
from langchain.schema.base import BaseOutputParser, Document
from langchain.schema.base import BaseOutputParser
from langchain.schema.document import Document


def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
@@ -8,7 +8,7 @@ from langchain.retrievers.document_compressors.base import BaseDocumentCompresso
from langchain.retrievers.document_compressors.chain_filter_prompt import (
    prompt_template,
)
from langchain.schema.base import Document
from langchain.schema.document import Document


def _get_default_chain_prompt() -> PromptTemplate:
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Dict, Sequence
from pydantic import Extra, root_validator

from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.utils import get_from_dict_or_env

if TYPE_CHECKING:
@@ -13,7 +13,7 @@ from langchain.math_utils import cosine_similarity
from langchain.retrievers.document_compressors.base import (
    BaseDocumentCompressor,
)
from langchain.schema.base import Document
from langchain.schema.document import Document


class EmbeddingsFilter(BaseDocumentCompressor):
@@ -1,11 +1,16 @@
"""Wrapper around Elasticsearch vector database."""

from __future__ import annotations

import uuid
from typing import Any, Iterable, List
from typing import Any, Iterable, List, Optional

from langchain.docstore.document import Document
from langchain.schema.base import BaseRetriever
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class ElasticSearchBM25Retriever(BaseRetriever):
@@ -111,7 +116,13 @@ class ElasticSearchBM25Retriever(BaseRetriever):
        self.client.indices.refresh(index=self.index_name)
        return ids

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        query_dict = {"query": {"match": {"content": query}}}
        res = self.client.search(index=self.index_name, body=query_dict)

@@ -120,5 +131,11 @@ class ElasticSearchBM25Retriever(BaseRetriever):
            docs.append(Document(page_content=r["_source"]["content"]))
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError
@@ -10,8 +10,13 @@ from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
@@ -39,7 +44,13 @@ class KNNRetriever(BaseRetriever, BaseModel):
        index = create_index(texts, embeddings)
        return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        query_embeds = np.array(self.embeddings.embed_query(query))
        # calc L2 norm
        index_embeds = self.index / np.sqrt((self.index**2).sum(1, keepdims=True))
@@ -61,5 +72,11 @@ class KNNRetriever(BaseRetriever, BaseModel):
        ]
        return top_k_results

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError
@@ -1,8 +1,13 @@
from typing import Any, Dict, List, cast
from typing import Any, Dict, List, Optional, cast

from pydantic import BaseModel, Field

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class LlamaIndexRetriever(BaseRetriever, BaseModel):
@@ -11,7 +16,13 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
    index: Any
    query_kwargs: Dict = Field(default_factory=dict)

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Get documents relevant for a query."""
        try:
            from llama_index.indices.base import BaseGPTIndex
@@ -33,7 +44,13 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
            )
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError("LlamaIndexRetriever does not support async")


@@ -43,7 +60,13 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
    graph: Any
    query_configs: List[Dict] = Field(default_factory=list)

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Get documents relevant for a query."""
        try:
            from llama_index.composability.graph import (
@@ -73,5 +96,11 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
            )
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError("LlamaIndexGraphRetriever does not support async")
@@ -1,6 +1,11 @@
from typing import Any, List, Optional

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class MetalRetriever(BaseRetriever):
@@ -15,7 +20,13 @@ class MetalRetriever(BaseRetriever):
        self.client: Metal = client
        self.params = params or {}

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        results = self.client.search({"text": query}, **self.params)
        final_results = []
        for r in results["data"]:
@@ -23,5 +34,11 @@ class MetalRetriever(BaseRetriever):
            final_results.append(Document(page_content=r["text"], metadata=metadata))
        return final_results

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError
@@ -1,8 +1,14 @@
"""Milvus Retriever"""

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.milvus import Milvus

# TODO: Update to MilvusClient + Hybrid Search when available
@@ -36,8 +42,21 @@ class MilvusRetreiver(BaseRetriever):
        """
        self.store.add_texts(texts, metadatas)

    def get_relevant_documents(self, query: str) -> List[Document]:
        return self.retriever.get_relevant_documents(query)
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        run_manager_ = run_manager or CallbackManagerForRetrieverRun.get_noop_manager()
        return self.retriever.retrieve(query, callbacks=run_manager_.get_child())

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

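The delegation pattern above (repeated in the Zilliz retriever later in this diff) is worth noting: when no run_manager is supplied, a no-op manager is substituted so that get_child() can always be called to scope the inner retriever's callbacks under the current run. A minimal sketch of the same pattern, using a hypothetical DelegatingRetriever wrapper that is not part of this change:

from typing import Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class DelegatingRetriever(BaseRetriever):
    """Hypothetical retriever that forwards every query to an inner retriever."""

    def __init__(self, inner: BaseRetriever) -> None:
        self.inner = inner

    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        # Fall back to a no-op manager so get_child() is always safe to call.
        run_manager_ = run_manager or CallbackManagerForRetrieverRun.get_noop_manager()
        # get_child() produces a callbacks object that nests the inner
        # retriever's events under this retriever's run.
        return self.inner.retrieve(query, callbacks=run_manager_.get_child())

    async def aget_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
        raise NotImplementedError
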
@@ -1,11 +1,17 @@
"""Taken from: https://docs.pinecone.io/docs/hybrid-search"""

import hashlib
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


def hash_text(text: str) -> str:
@@ -116,7 +122,13 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
        )
        return values

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        from pinecone_text.hybrid import hybrid_convex_scale

        sparse_vec = self.sparse_encoder.encode_queries(query)
@@ -141,5 +153,11 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
        # return search results as json
        return final_result

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

@@ -1,6 +1,11 @@
from typing import List
from typing import Any, List, Optional

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.utilities.pupmed import PubMedAPIWrapper


@@ -11,8 +16,20 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
    It uses all PubMedAPIWrapper arguments without any change.
    """

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        return self.load_docs(query=query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

@@ -1,10 +1,15 @@
from typing import List, Optional
from typing import Any, List, Optional

import aiohttp
import requests
from pydantic import BaseModel

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class RemoteLangChainRetriever(BaseRetriever, BaseModel):
@@ -15,7 +20,13 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
    page_content_key: str = "page_content"
    metadata_key: str = "metadata"

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        response = requests.post(
            self.url, json={self.input_key: query}, headers=self.headers
        )
@@ -27,7 +38,13 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
            for r in result[self.response_key]
        ]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        async with aiohttp.ClientSession() as session:
            async with session.request(
                "POST", self.url, headers=self.headers, json={self.input_key: query}

@@ -1,10 +1,15 @@
"""Retriever that generates and executes structured queries over its own data source."""

from typing import Any, Dict, List, Optional, Type, cast

from pydantic import BaseModel, Field, root_validator

from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
@@ -12,7 +17,8 @@ from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores import Chroma, Pinecone, Qdrant, VectorStore, Weaviate


@@ -65,7 +71,13 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
        )
        return values

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
@@ -90,7 +102,13 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
        docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

    @classmethod

@@ -10,8 +10,13 @@ from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
@@ -39,7 +44,13 @@ class SVMRetriever(BaseRetriever, BaseModel):
        index = create_index(texts, embeddings)
        return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        from sklearn import svm

        query_embeds = np.array(self.embeddings.embed_query(query))
@@ -76,5 +87,11 @@ class SVMRetriever(BaseRetriever, BaseModel):
            top_k_results.append(Document(page_content=self.texts[row - 1]))
        return top_k_results

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

@@ -2,13 +2,19 @@

Largely based on
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""

from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional

from pydantic import BaseModel

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class TFIDFRetriever(BaseRetriever, BaseModel):
@@ -58,7 +64,13 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
            texts=texts, tfidf_params=tfidf_params, metadatas=metadatas, **kwargs
        )

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        from sklearn.metrics.pairwise import cosine_similarity

        query_vec = self.vectorizer.transform(
@@ -70,5 +82,11 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
        return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
        return return_docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

@@ -1,11 +1,17 @@
"""Retriever that combines embedding similarity with recency in retrieving values."""

import datetime
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, Field

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.base import VectorStore


@@ -80,7 +86,13 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
            results[buffer_idx] = (doc, relevance)
        return results

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return documents that are relevant to the query."""
        current_time = datetime.datetime.now()
        docs_and_scores = {
@@ -103,7 +115,13 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
            result.append(buffered_doc)
        return result

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return documents that are relevant to the query."""
        raise NotImplementedError

@@ -1,10 +1,16 @@
"""Wrapper for retrieving documents from Vespa."""

from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever

if TYPE_CHECKING:
    from vespa.application import Vespa
@@ -48,12 +54,24 @@ class VespaRetriever(BaseRetriever):
            docs.append(Document(page_content=page_content, metadata=metadata))
        return docs

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        body = self._query_body.copy()
        body["query"] = query
        return self._query(body)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

    def get_relevant_documents_with_filter(

@@ -6,8 +6,12 @@ from uuid import uuid4

from pydantic import Extra

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.schema.base import BaseRetriever
from langchain.schema.retriever import BaseRetriever


class WeaviateHybridSearchRetriever(BaseRetriever):
@@ -83,7 +87,12 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
        return ids

    def get_relevant_documents(
        self, query: str, where_filter: Optional[Dict[str, object]] = None
        self,
        query: str,
        *,
        where_filter: Optional[Dict[str, object]] = None,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Look up similar documents in Weaviate."""
        query_obj = self._client.query.get(self._index_name, self._query_attrs)
@@ -102,6 +111,11 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
        return docs

    async def aget_relevant_documents(
        self, query: str, where_filter: Optional[Dict[str, object]] = None
        self,
        query: str,
        *,
        where_filter: Optional[Dict[str, object]] = None,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

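Note that inserting the bare * here makes where_filter keyword-only, a small breaking change for positional call sites. A sketch of the migration, assuming an existing retriever instance and an illustrative Weaviate filter:

# Before this change, passing the filter positionally worked:
# docs = retriever.get_relevant_documents(
#     "langchain", {"path": ["title"], "operator": "Equal", "valueText": "LangChain"}
# )
# After it, the filter must be passed by keyword:
docs = retriever.get_relevant_documents(
    "langchain",
    where_filter={"path": ["title"], "operator": "Equal", "valueText": "LangChain"},
)
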
@@ -1,6 +1,11 @@
from typing import List
from typing import Any, List, Optional

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.utilities.wikipedia import WikipediaAPIWrapper


@@ -11,8 +16,20 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
    It uses all WikipediaAPIWrapper arguments without any change.
    """

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        return self.load(query=query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

@@ -1,8 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain.schema.base import BaseRetriever, Document
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever

if TYPE_CHECKING:
    from zep_python import MemorySearchResult
@@ -54,7 +59,12 @@ class ZepRetriever(BaseRetriever):
        ]

    def get_relevant_documents(
        self, query: str, metadata: Optional[Dict] = None
        self,
        query: str,
        *,
        metadata: Optional[Dict] = None,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        from zep_python import MemorySearchPayload

@@ -69,7 +79,12 @@ class ZepRetriever(BaseRetriever):
        return self._search_result_to_doc(results)

    async def aget_relevant_documents(
        self, query: str, metadata: Optional[Dict] = None
        self,
        query: str,
        *,
        metadata: Optional[Dict] = None,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        from zep_python import MemorySearchPayload

@@ -1,8 +1,14 @@
"""Zilliz Retriever"""

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever, Document
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.zilliz import Zilliz

# TODO: Update to ZillizClient + Hybrid Search when available
@@ -36,8 +42,21 @@ class ZillizRetreiver(BaseRetriever):
        """
        self.store.add_texts(texts, metadatas)

    def get_relevant_documents(self, query: str) -> List[Document]:
        return self.retriever.get_relevant_documents(query)
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        _run_manager = run_manager or CallbackManagerForRetrieverRun.get_noop_manager()
        return self.retriever.retrieve(query, callbacks=_run_manager.get_child())

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

@@ -6,15 +6,12 @@ from langchain.schema.base import (
    AgentFinish,
    AIMessage,
    BaseChatMessageHistory,
    BaseDocumentTransformer,
    BaseMemory,
    BaseMessage,
    BaseOutputParser,
    BaseRetriever,
    ChatGeneration,
    ChatMessage,
    ChatResult,
    Document,
    Generation,
    HumanMessage,
    LLMResult,
@@ -30,6 +27,8 @@ from langchain.schema.base import (
    messages_from_dict,
    messages_to_dict,
)
from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.schema.retriever import BaseRetriever

__all__ = [
    "AIMessage",

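The net effect of this __init__ shuffle is that Document, BaseDocumentTransformer, and BaseRetriever move to dedicated modules but remain re-exported from the package, so both import styles below should keep working (a sketch, assuming the re-exports land in langchain.schema as shown above):

# New canonical locations introduced by this change:
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever

# Existing call sites importing from the package keep working via the re-exports:
from langchain.schema import BaseRetriever, Document
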
@@ -285,37 +285,6 @@ class BaseChatMessageHistory(ABC):
        """Remove all messages from the store"""


class Document(BaseModel):
    """Interface for interacting with a document."""

    page_content: str
    metadata: dict = Field(default_factory=dict)


class BaseRetriever(ABC):
    @abstractmethod
    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """

    @abstractmethod
    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """


# For backwards compatibility

@@ -405,19 +374,3 @@ class OutputParserException(ValueError):
        self.observation = observation
        self.llm_output = llm_output
        self.send_to_llm = send_to_llm


class BaseDocumentTransformer(ABC):
    """Base interface for transforming documents."""

    @abstractmethod
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform a list of documents."""

    @abstractmethod
    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a list of documents."""

langchain/schema/document.py (new file, 30 lines)
@@ -0,0 +1,30 @@
"""Schema for a document."""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Sequence

from pydantic import BaseModel, Field


class Document(BaseModel):
    """Interface for interacting with a document."""

    page_content: str
    metadata: dict = Field(default_factory=dict)


class BaseDocumentTransformer(ABC):
    """Base interface for transforming documents."""

    @abstractmethod
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform a list of documents."""

    @abstractmethod
    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a list of documents."""

langchain/schema/retriever.py (new file, 133 lines)
@@ -0,0 +1,133 @@
"""Schema for a document."""
from __future__ import annotations

from abc import ABC, abstractmethod
from inspect import signature
from typing import (
    Any,
    List,
    Optional,
)

from langchain.callbacks.manager import (
    AsyncCallbackManager,
    AsyncCallbackManagerForRetrieverRun,
    CallbackManager,
    CallbackManagerForRetrieverRun,
    Callbacks,
)
from langchain.schema.document import Document


class BaseRetriever(ABC):
    """Base interface for a retriever."""

    _new_arg_supported: bool = False

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        cls._new_arg_supported = (
            signature(cls.get_relevant_documents).parameters.get("run_manager")
            is not None
        )

    @abstractmethod
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Get documents relevant for a query.
        Args:
            query: string to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """

    @abstractmethod
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Get documents relevant for a query.
        Args:
            query: string to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """

    def retrieve(
        self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
    ) -> List[Document]:
        """Retrieve documents.
        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks
        Returns:
            List of relevant documents
        """
        callback_manager = CallbackManager.configure(
            callbacks, None, verbose=kwargs.get("verbose", False)
        )
        run_manager = callback_manager.on_retriever_start(
            query,
            **kwargs,
        )
        try:
            # TODO: maybe also pass through run_manager if _run supports kwargs
            if self._new_arg_supported:
                result = self.get_relevant_documents(
                    query, run_manager=run_manager, **kwargs
                )
            else:
                result = self.get_relevant_documents(query)
        except Exception as e:
            run_manager.on_retriever_error(e)
            raise e
        else:
            run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result

    async def aretrieve(
        self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
    ) -> List[Document]:
        """Get documents relevant for a query.
        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks
        Returns:
            List of relevant documents
        """
        callback_manager = AsyncCallbackManager.configure(
            callbacks, None, verbose=kwargs.get("verbose", False)
        )
        run_manager = await callback_manager.on_retriever_start(
            query,
            **kwargs,
        )
        try:
            if self._new_arg_supported:
                result = await self.aget_relevant_documents(
                    query, run_manager=run_manager, **kwargs
                )
            else:
                result = await self.aget_relevant_documents(query)
        except Exception as e:
            await run_manager.on_retriever_error(e)
            raise e
        else:
            await run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result

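Taken together: a subclass opts into callbacks simply by declaring the keyword-only run_manager parameter; __init_subclass__ detects it via signature inspection, and retrieve() either threads the run manager through or falls back to the legacy single-argument call. A minimal end-to-end sketch of the intended usage; the InMemoryRetriever and PrintRetrieverEvents classes are illustrative, not part of this change, and BaseCallbackHandler is assumed importable from langchain.callbacks.base as in the handler file touched above:

from typing import Any, List, Optional, Sequence

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever


class InMemoryRetriever(BaseRetriever):
    """Toy retriever: naive substring match over an in-memory corpus."""

    def __init__(self, docs: List[Document]) -> None:
        self.docs = docs

    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        # Declaring run_manager makes _new_arg_supported True, so retrieve()
        # passes the run manager through to this method.
        return [d for d in self.docs if query.lower() in d.page_content.lower()]

    async def aget_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
        raise NotImplementedError


class PrintRetrieverEvents(BaseCallbackHandler):
    """Logs the retriever lifecycle events added in this diff."""

    def on_retriever_start(self, query: str, **kwargs: Any) -> None:
        print(f"retriever start: {query!r}")

    def on_retriever_end(self, documents: Sequence[Document], **kwargs: Any) -> None:
        print(f"retriever end: {len(documents)} document(s)")


retriever = InMemoryRetriever([Document(page_content="Retriever callbacks land in LangChain")])
docs = retriever.retrieve("callbacks", callbacks=[PrintRetrieverEvents()])
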
@@ -22,8 +22,7 @@ from typing import (
    Union,
)

from langchain.docstore.document import Document
from langchain.schema.base import BaseDocumentTransformer
from langchain.schema.document import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)


@@ -5,7 +5,7 @@ from typing import Any, Dict, List

from pydantic import BaseModel, Extra, root_validator

from langchain.schema.base import Document
from langchain.schema.document import Document

logger = logging.getLogger(__name__)


@@ -7,7 +7,7 @@ from typing import List

from pydantic import BaseModel, Extra

from langchain.schema.base import Document
from langchain.schema.document import Document

logger = logging.getLogger(__name__)


@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.schema.base import Document
from langchain.schema.document import Document

logger = logging.getLogger(__name__)

@@ -1,4 +1,5 @@
"""Interface for vector stores."""

from __future__ import annotations

import asyncio
@@ -20,9 +21,13 @@ from typing import (

from pydantic import BaseModel, Field, root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema.base import BaseRetriever
from langchain.schema.retriever import BaseRetriever

VST = TypeVar("VST", bound="VectorStore")

@@ -387,7 +392,13 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
        )
        return values

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
@@ -405,7 +416,13 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs

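Because VectorStoreRetriever now declares run_manager, any retriever produced by VectorStore.as_retriever() participates in the new callback flow through the inherited retrieve() wrapper. A usage sketch, assuming a FAISS store and reusing the illustrative PrintRetrieverEvents handler sketched earlier:

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

vectorstore = FAISS.from_texts(
    ["retriever callbacks", "unrelated text"], OpenAIEmbeddings()
)
retriever = vectorstore.as_retriever(search_type="similarity")
# retrieve() (from the new BaseRetriever) fires on_retriever_start/end
# around the underlying similarity search.
docs = retriever.retrieve("callbacks", callbacks=[PrintRetrieverEvents()])
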
@@ -5,7 +5,7 @@ import numpy as np
from pydantic import Field

from langchain.embeddings.base import Embeddings
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance


@@ -1,4 +1,5 @@
"""Wrapper around Redis vector database."""

from __future__ import annotations

import json
@@ -21,6 +22,10 @@ from typing import (
import numpy as np
from pydantic import BaseModel, root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -567,7 +572,13 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
            raise ValueError(f"search_type of {search_type} not allowed.")
        return values

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(query, k=self.k)
        elif self.search_type == "similarity_limit":
@@ -578,7 +589,13 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError("RedisVectorStoreRetriever does not support async")

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:

@@ -1,4 +1,5 @@
"""Wrapper around SingleStore DB."""

from __future__ import annotations

import json
@@ -15,6 +16,10 @@ from typing import (

from sqlalchemy.pool import QueuePool

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
@@ -359,14 +364,26 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
    k: int = 4
    allowed_search_types: ClassVar[Collection[str]] = ("similarity",)

    def get_relevant_documents(self, query: str) -> List[Document]:
    def get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(query, k=self.k)
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Document]:
    async def aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError(
            "SingleStoreDBVectorStoreRetriever does not support async"
        )

@@ -4,7 +4,7 @@ import itertools
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple

from langchain.embeddings.base import Embeddings
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores import VectorStore

if TYPE_CHECKING:

@@ -11,7 +11,7 @@ import requests
from pydantic import Field

from langchain.embeddings.base import Embeddings
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores.base import VectorStore, VectorStoreRetriever


@@ -1,7 +1,7 @@
from typing import List

from langchain.document_loaders.arxiv import ArxivLoader
from langchain.schema.base import Document
from langchain.schema.document import Document


def assert_docs(docs: List[Document]) -> None:

@@ -2,7 +2,7 @@ import pandas as pd
import pytest

from langchain.document_loaders import DataFrameLoader
from langchain.schema.base import Document
from langchain.schema.document import Document


@pytest.fixture

@@ -5,7 +5,7 @@ from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline,
    EmbeddingsFilter,
)
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.text_splitter import CharacterTextSplitter


@@ -1,7 +1,7 @@
"""Integration test for LLMChainExtractor."""
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.schema.base import Document
from langchain.schema.document import Document


def test_llm_construction_with_kwargs() -> None:

@@ -1,7 +1,7 @@
"""Integration test for llm-based relevant doc filtering."""
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.schema.base import Document
from langchain.schema.document import Document


def test_llm_chain_filter() -> None:

@@ -4,7 +4,7 @@ import numpy as np
from langchain.document_transformers import _DocumentWithState
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema.base import Document
from langchain.schema.document import Document


def test_embeddings_filter() -> None:

@@ -4,7 +4,7 @@ from typing import List
import pytest

from langchain.retrievers import ArxivRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document


@pytest.fixture

@@ -2,7 +2,7 @@
import pytest

from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document


def test_azure_cognitive_search_get_relevant_documents() -> None:

@@ -4,7 +4,7 @@ from typing import List
import pytest

from langchain.retrievers import PubMedRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document


@pytest.fixture

@@ -4,7 +4,7 @@ from typing import List
import pytest

from langchain.retrievers import WikipediaRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document


@pytest.fixture

@@ -4,7 +4,7 @@ from langchain.document_transformers import (
    _DocumentWithState,
)
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.base import Document
from langchain.schema.document import Document


def test_embeddings_redundant_filter() -> None:

@@ -4,7 +4,7 @@ from typing import Any, List
import pytest

from langchain.agents.load_tools import load_tools
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.tools.base import BaseTool
from langchain.utilities import ArxivAPIWrapper


@@ -4,7 +4,7 @@ from typing import Any, List
import pytest

from langchain.agents.load_tools import load_tools
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.tools.base import BaseTool
from langchain.utilities import PubMedAPIWrapper


@@ -3,7 +3,7 @@ from typing import List

import pytest

from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.utilities import WikipediaAPIWrapper


@@ -6,7 +6,7 @@ from vcr.request import Request

from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.text_splitter import CharacterTextSplitter

# Those environment variables turn on Deep Lake pytest mode.

@@ -4,7 +4,7 @@ from typing import List
import numpy as np
import pytest

from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores.docarray import DocArrayHnswSearch
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings


@@ -4,7 +4,7 @@ from typing import List
import numpy as np
import pytest

from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores.docarray import DocArrayInMemorySearch
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings


@@ -1,5 +1,5 @@
from langchain.docstore.arbitrary_fn import DocstoreFn
from langchain.schema.base import Document
from langchain.schema.document import Document


def test_document_found() -> None:

@@ -7,7 +7,7 @@ import pytest
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.document_loaders.parsers.generic import MimeTypeBasedParser
from langchain.schema.base import Document
from langchain.schema.document import Document


class TestMimeBasedParser:

@@ -3,7 +3,7 @@ from typing import Iterator

from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema.base import Document
from langchain.schema.document import Document


def test_base_blob_parser() -> None:

@@ -9,7 +9,7 @@ import pytest
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob, FileSystemBlobLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.schema.base import Document
from langchain.schema.document import Document


@pytest.fixture

@@ -1,7 +1,7 @@
import pytest

from langchain.retrievers.tfidf import TFIDFRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document


@pytest.mark.requires("sklearn")

@@ -10,7 +10,7 @@ from langchain.retrievers.time_weighted_retriever import (
    TimeWeightedVectorStoreRetriever,
    _get_hours_passed,
)
from langchain.schema.base import Document
from langchain.schema.document import Document
from langchain.vectorstores.base import VectorStore


@@ -7,7 +7,7 @@ import pytest
from pytest_mock import MockerFixture

from langchain.retrievers import ZepRetriever
from langchain.schema.base import Document
from langchain.schema.document import Document

if TYPE_CHECKING:
    from zep_python import MemorySearchResult, ZepClient