mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-02 04:58:46 +00:00
cohere[patch]: Fix retriever (#19771)
* Replace `source_documents` with `documents` * Pass `documents` as a named arg vs keyword * Make `parsed_docs` more robust * Fix edge case of doc page_content being `None`
This commit is contained in:
parent
b6ebddbacc
commit
8cf1d75d08
@ -17,6 +17,7 @@ from langchain_core.callbacks import (
|
|||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
|
from langchain_core.documents import Document
|
||||||
from langchain_core.language_models import LanguageModelInput
|
from langchain_core.language_models import LanguageModelInput
|
||||||
from langchain_core.language_models.chat_models import (
|
from langchain_core.language_models.chat_models import (
|
||||||
BaseChatModel,
|
BaseChatModel,
|
||||||
@ -73,7 +74,7 @@ def get_role(message: BaseMessage) -> str:
|
|||||||
def get_cohere_chat_request(
|
def get_cohere_chat_request(
|
||||||
messages: List[BaseMessage],
|
messages: List[BaseMessage],
|
||||||
*,
|
*,
|
||||||
documents: Optional[List[Dict[str, str]]] = None,
|
documents: Optional[List[Document]] = None,
|
||||||
connectors: Optional[List[Dict[str, str]]] = None,
|
connectors: Optional[List[Dict[str, str]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@ -95,17 +96,25 @@ def get_cohere_chat_request(
|
|||||||
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
|
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parsed_docs: Optional[List[Document]] = None
|
||||||
|
if "documents" in additional_kwargs:
|
||||||
|
parsed_docs = (
|
||||||
|
additional_kwargs["documents"]
|
||||||
|
if len(additional_kwargs["documents"]) > 0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
elif documents is not None and len(documents) > 0:
|
||||||
|
parsed_docs = documents
|
||||||
|
|
||||||
formatted_docs: Optional[List[Dict[str, Any]]] = None
|
formatted_docs: Optional[List[Dict[str, Any]]] = None
|
||||||
if additional_kwargs.get("documents"):
|
if parsed_docs is not None:
|
||||||
formatted_docs = [
|
formatted_docs = [
|
||||||
{
|
{
|
||||||
"text": doc.page_content,
|
"text": doc.page_content,
|
||||||
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
||||||
}
|
}
|
||||||
for i, doc in enumerate(additional_kwargs.get("documents", []))
|
for i, doc in enumerate(parsed_docs)
|
||||||
]
|
]
|
||||||
elif documents:
|
|
||||||
formatted_docs = documents
|
|
||||||
|
|
||||||
# by enabling automatic prompt truncation, the probability of request failure is
|
# by enabling automatic prompt truncation, the probability of request failure is
|
||||||
# reduced with minimal impact on response quality
|
# reduced with minimal impact on response quality
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -17,15 +17,16 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
def _get_docs(response: Any) -> List[Document]:
|
def _get_docs(response: Any) -> List[Document]:
|
||||||
docs = (
|
docs = []
|
||||||
[]
|
if (
|
||||||
if "documents" not in response.generation_info
|
"documents" in response.generation_info
|
||||||
or len(response.generation_info["documents"]) == 0
|
and len(response.generation_info["documents"]) > 0
|
||||||
else [
|
):
|
||||||
Document(page_content=doc["snippet"], metadata=doc)
|
for doc in response.generation_info["documents"]:
|
||||||
for doc in response.generation_info["documents"]
|
content = doc.get("snippet", None) or doc.get("text", None)
|
||||||
]
|
if content is not None:
|
||||||
)
|
docs.append(Document(page_content=content, metadata=doc))
|
||||||
|
|
||||||
docs.append(
|
docs.append(
|
||||||
Document(
|
Document(
|
||||||
page_content=response.message.content,
|
page_content=response.message.content,
|
||||||
@ -63,12 +64,18 @@ class CohereRagRetriever(BaseRetriever):
|
|||||||
"""Allow arbitrary types."""
|
"""Allow arbitrary types."""
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
run_manager: CallbackManagerForRetrieverRun,
|
||||||
|
documents: Optional[List[Dict[str, str]]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
|
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
|
||||||
res = self.llm.generate(
|
res = self.llm.generate(
|
||||||
messages,
|
messages,
|
||||||
connectors=self.connectors,
|
connectors=self.connectors if documents is None else None,
|
||||||
|
documents=documents,
|
||||||
callbacks=run_manager.get_child(),
|
callbacks=run_manager.get_child(),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
).generations[0][0]
|
).generations[0][0]
|
||||||
@ -79,13 +86,15 @@ class CohereRagRetriever(BaseRetriever):
|
|||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
|
documents: Optional[List[Dict[str, str]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
|
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
|
||||||
res = (
|
res = (
|
||||||
await self.llm.agenerate(
|
await self.llm.agenerate(
|
||||||
messages,
|
messages,
|
||||||
connectors=self.connectors,
|
connectors=self.connectors if documents is None else None,
|
||||||
|
documents=documents,
|
||||||
callbacks=run_manager.get_child(),
|
callbacks=run_manager.get_child(),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user