Files
langchain/libs/community/langchain_community/retrievers/cohere_rag_retriever.py
Harrison Chase 8516a03a02 langchain-community[major]: Upgrade community to pydantic 2 (#26011)
This PR upgrades langchain-community to pydantic 2.


* Most of this PR was auto-generated using code mods with gritql
(https://github.com/eyurtsev/migrate-pydantic/tree/main)
* Subsequently, some code was fixed manually to accommodate
differences between pydantic 1 and 2 (a typical change is sketched below)
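A minimal sketch of the kind of change such a codemod makes (illustrative only; `ExampleModel` is a made-up name, not taken from this diff): the pydantic 1 inner `class Config` becomes a pydantic 2 `model_config = ConfigDict(...)`, which is the pattern `CohereRagRetriever` in this file now uses.

    # Before (pydantic 1):
    #     class ExampleModel(BaseModel):
    #         class Config:
    #             arbitrary_types_allowed = True
    #
    # After (pydantic 2):
    from pydantic import BaseModel, ConfigDict

    class ExampleModel(BaseModel):
        model_config = ConfigDict(arbitrary_types_allowed=True)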

Breaking Changes:

- Use TEXTEMBED_API_KEY and TEXTEMBED_API_URL as the env variables for the
TextEmbed embeddings integration (see the sketch after the commit reference):
cbea780492
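A minimal sketch of setting the renamed variables before using the TextEmbed integration; the values below are placeholders, not defaults taken from the code.

    import os

    # Placeholder values; the variable names follow the breaking change above.
    os.environ["TEXTEMBED_API_KEY"] = "your-api-key"
    os.environ["TEXTEMBED_API_URL"] = "http://localhost:8000/v1"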

Other changes:

- Added pydantic_settings as a required dependency for community. This
may be removed if we have enough time to convert the dependency into an
optional one.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2024-09-05 14:07:10 -04:00

97 lines
2.9 KiB
Python

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict, Field

if TYPE_CHECKING:
    from langchain_core.messages import BaseMessage


def _get_docs(response: Any) -> List[Document]:
    docs = (
        []
        if "documents" not in response.generation_info
        else [
            Document(page_content=doc["snippet"], metadata=doc)
            for doc in response.generation_info["documents"]
        ]
    )
    docs.append(
        Document(
            page_content=response.message.content,
            metadata={
                "type": "model_response",
                "citations": response.generation_info["citations"],
                "search_results": response.generation_info["search_results"],
                "search_queries": response.generation_info["search_queries"],
                "token_count": response.generation_info["token_count"],
            },
        )
    )
    return docs


@deprecated(
    since="0.0.30",
    removal="1.0",
    alternative_import="langchain_cohere.CohereRagRetriever",
)
class CohereRagRetriever(BaseRetriever):
    """Cohere Chat API with RAG."""

    connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
    """
    When specified, the model's reply will be enriched with information found by
    querying each of the connectors (RAG). These will be returned as langchain
    documents.

    Currently only accepts {"id": "web-search"}.
    """

    llm: BaseChatModel
    """Cohere ChatModel to use."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
        res = self.llm.generate(
            messages,
            connectors=self.connectors,
            callbacks=run_manager.get_child(),
            **kwargs,
        ).generations[0][0]
        return _get_docs(res)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
        res = (
            await self.llm.agenerate(
                messages,
                connectors=self.connectors,
                callbacks=run_manager.get_child(),
                **kwargs,
            )
        ).generations[0][0]
        return _get_docs(res)
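
A usage sketch for the retriever above (not part of the file): it assumes ChatCohere from langchain_community.chat_models, a valid COHERE_API_KEY in the environment, and a Cohere account with the web-search connector enabled.

# Illustrative only; not part of this module.
from langchain_community.chat_models import ChatCohere
from langchain_community.retrievers import CohereRagRetriever

retriever = CohereRagRetriever(llm=ChatCohere())
docs = retriever.invoke("What is retrieval-augmented generation?")
for doc in docs:
    # Connector documents carry the raw snippet; the final document is the
    # model's own response, with citations and search metadata attached.
    print(doc.metadata.get("type", "document"), doc.page_content[:80])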