mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-31 12:09:58 +00:00
cohere[patch]: Fix retriever (#19771)
* Replace `source_documents` with `documents` * Pass `documents` as a named arg vs keyword * Make `parsed_docs` more robust * Fix edge case of doc page_content being `None`
This commit is contained in:
parent
b6ebddbacc
commit
8cf1d75d08
@ -17,6 +17,7 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
@ -73,7 +74,7 @@ def get_role(message: BaseMessage) -> str:
|
||||
def get_cohere_chat_request(
|
||||
messages: List[BaseMessage],
|
||||
*,
|
||||
documents: Optional[List[Dict[str, str]]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
connectors: Optional[List[Dict[str, str]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
@ -95,17 +96,25 @@ def get_cohere_chat_request(
|
||||
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
|
||||
)
|
||||
|
||||
parsed_docs: Optional[List[Document]] = None
|
||||
if "documents" in additional_kwargs:
|
||||
parsed_docs = (
|
||||
additional_kwargs["documents"]
|
||||
if len(additional_kwargs["documents"]) > 0
|
||||
else None
|
||||
)
|
||||
elif documents is not None and len(documents) > 0:
|
||||
parsed_docs = documents
|
||||
|
||||
formatted_docs: Optional[List[Dict[str, Any]]] = None
|
||||
if additional_kwargs.get("documents"):
|
||||
if parsed_docs is not None:
|
||||
formatted_docs = [
|
||||
{
|
||||
"text": doc.page_content,
|
||||
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
||||
}
|
||||
for i, doc in enumerate(additional_kwargs.get("documents", []))
|
||||
for i, doc in enumerate(parsed_docs)
|
||||
]
|
||||
elif documents:
|
||||
formatted_docs = documents
|
||||
|
||||
# by enabling automatic prompt truncation, the probability of request failure is
|
||||
# reduced with minimal impact on response quality
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
@ -17,15 +17,16 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def _get_docs(response: Any) -> List[Document]:
|
||||
docs = (
|
||||
[]
|
||||
if "documents" not in response.generation_info
|
||||
or len(response.generation_info["documents"]) == 0
|
||||
else [
|
||||
Document(page_content=doc["snippet"], metadata=doc)
|
||||
for doc in response.generation_info["documents"]
|
||||
]
|
||||
)
|
||||
docs = []
|
||||
if (
|
||||
"documents" in response.generation_info
|
||||
and len(response.generation_info["documents"]) > 0
|
||||
):
|
||||
for doc in response.generation_info["documents"]:
|
||||
content = doc.get("snippet", None) or doc.get("text", None)
|
||||
if content is not None:
|
||||
docs.append(Document(page_content=content, metadata=doc))
|
||||
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=response.message.content,
|
||||
@ -63,12 +64,18 @@ class CohereRagRetriever(BaseRetriever):
|
||||
"""Allow arbitrary types."""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
documents: Optional[List[Dict[str, str]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
|
||||
res = self.llm.generate(
|
||||
messages,
|
||||
connectors=self.connectors,
|
||||
connectors=self.connectors if documents is None else None,
|
||||
documents=documents,
|
||||
callbacks=run_manager.get_child(),
|
||||
**kwargs,
|
||||
).generations[0][0]
|
||||
@ -79,13 +86,15 @@ class CohereRagRetriever(BaseRetriever):
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
documents: Optional[List[Dict[str, str]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
|
||||
res = (
|
||||
await self.llm.agenerate(
|
||||
messages,
|
||||
connectors=self.connectors,
|
||||
connectors=self.connectors if documents is None else None,
|
||||
documents=documents,
|
||||
callbacks=run_manager.get_child(),
|
||||
**kwargs,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user