from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict, Field

if TYPE_CHECKING:
    from langchain_core.messages import BaseMessage

def _get_docs(response: Any) -> List[Document]:
    """Map a Cohere RAG chat response onto a list of LangChain ``Document``s."""
    # Retrieved connector documents (if any) become one Document per snippet,
    # with the raw document dict kept as metadata.
    docs = (
        []
        if "documents" not in response.generation_info
        else [
            Document(page_content=doc["snippet"], metadata=doc)
            for doc in response.generation_info["documents"]
        ]
    )
    # The model's own reply is appended last, carrying the grounding metadata
    # (citations, search results, search queries, token count).
    docs.append(
        Document(
            page_content=response.message.content,
            metadata={
                "type": "model_response",
                "citations": response.generation_info["citations"],
                "search_results": response.generation_info["search_results"],
                "search_queries": response.generation_info["search_queries"],
                "token_count": response.generation_info["token_count"],
            },
        )
    )
    return docs

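
# A minimal, commented sketch (kept as comments so importing this module stays
# side-effect free) of the shape `_get_docs` expects and produces. The response
# object below is a stand-in built with SimpleNamespace, not the real
# ChatGeneration class, and the field values are illustrative only:
#
#     from types import SimpleNamespace
#
#     fake_response = SimpleNamespace(
#         message=SimpleNamespace(content="Grounded answer text."),
#         generation_info={
#             "documents": [{"snippet": "A retrieved snippet", "title": "Source"}],
#             "citations": [],
#             "search_results": [],
#             "search_queries": [],
#             "token_count": {"total_tokens": 42},
#         },
#     )
#     docs = _get_docs(fake_response)
#     # docs[0] wraps the retrieved snippet; docs[-1] is the model's own reply,
#     # tagged "type": "model_response" in its metadata.
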
@deprecated(
    since="0.0.30",
    removal="1.0",
    alternative_import="langchain_cohere.CohereRagRetriever",
)
class CohereRagRetriever(BaseRetriever):
    """Cohere Chat API with RAG."""

    connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
    """
    When specified, the model's reply will be enriched with information found by
    querying each of the connectors (RAG). These will be returned as langchain
    documents.

    Currently only accepts {"id": "web-search"}.
    """

    llm: BaseChatModel
    """Cohere ChatModel to use."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        # The query is wrapped in a single HumanMessage; passing `connectors`
        # asks Cohere's chat API to ground the reply and return source documents.
        messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
        res = self.llm.generate(
            messages,
            connectors=self.connectors,
            callbacks=run_manager.get_child(),
            **kwargs,
        ).generations[0][0]
        return _get_docs(res)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
        res = (
            await self.llm.agenerate(
                messages,
                connectors=self.connectors,
                callbacks=run_manager.get_child(),
                **kwargs,
            )
        ).generations[0][0]
        return _get_docs(res)
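

# A minimal usage sketch, kept as comments so this module has no import-time
# side effects. It assumes the (likewise deprecated) ChatCohere chat model from
# langchain_community and a valid COHERE_API_KEY in the environment; new code
# should prefer langchain_cohere.CohereRagRetriever, as the decorator above notes.
#
#     from langchain_community.chat_models import ChatCohere
#
#     retriever = CohereRagRetriever(llm=ChatCohere())
#     docs = retriever.get_relevant_documents("What is retrieval-augmented generation?")
#     # The last Document is the model's reply ("type": "model_response"); the
#     # preceding Documents are the web-search snippets it was grounded on.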