Mirror of https://github.com/hwchase17/langchain.git
community: Add support for cohere SDK v5 (keeps v4 backwards compatibility) (#19084)
- **Description:** Add support for cohere SDK v5 (keeps v4 backwards compatibility)

Co-authored-by: Erick Friis <erick@langchain.dev>
commit 7253b816cc (parent 06165efb5b), committed by GitHub
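From the caller's side nothing changes: the same `ChatCohere` streaming code runs against either SDK major, because the version branch happens inside the chat model. A minimal usage sketch (the API key, model name, and prompt are placeholders):

```python
from langchain_community.chat_models import ChatCohere

# Works with either cohere SDK v4 or v5 installed: the model uses
# chat_stream() when the client exposes it and falls back to
# chat(stream=True) otherwise.
chat = ChatCohere(cohere_api_key="YOUR_API_KEY", model="command")
for chunk in chat.stream("Write a haiku about the sea"):
    print(chunk.content, end="", flush=True)
```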
@@ -80,7 +80,7 @@ def get_cohere_chat_request(
         "AUTO" if documents is not None or connectors is not None else None
     )
 
-    return {
+    req = {
         "message": messages[-1].content,
         "chat_history": [
             {"role": get_role(x), "message": x.content} for x in messages[:-1]
@@ -91,6 +91,8 @@ def get_cohere_chat_request(
         **kwargs,
     }
+
+    return {k: v for k, v in req.items() if v is not None}
 
 
 class ChatCohere(BaseChatModel, BaseCohere):
     """`Cohere` chat large language models.
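Building the payload as an intermediate `req` dict and then dropping `None` entries keeps optional parameters that were never set out of the request entirely, rather than forwarding explicit `None` values to the SDK. A standalone sketch of the same pattern (the function name is illustrative, not the module's):

```python
from typing import Any, Dict, List, Optional

def build_request(
    message: str,
    documents: Optional[List[Dict[str, Any]]] = None,
    connectors: Optional[List[Dict[str, str]]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Assemble a chat payload, keeping only parameters that were actually set."""
    req = {
        "message": message,
        "documents": documents,
        "connectors": connectors,
        **kwargs,
    }
    # Unset optional fields are omitted instead of being sent as null.
    return {k: v for k, v in req.items() if v is not None}

print(build_request("hi"))                   # {'message': 'hi'}
print(build_request("hi", temperature=0.3))  # {'message': 'hi', 'temperature': 0.3}
```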
@@ -142,7 +144,11 @@ class ChatCohere(BaseChatModel, BaseCohere):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
-        stream = self.client.chat(**request, stream=True)
+
+        if hasattr(self.client, "chat_stream"):  # detect and support sdk v5
+            stream = self.client.chat_stream(**request)
+        else:
+            stream = self.client.chat(**request, stream=True)
 
         for data in stream:
             if data.event_type == "text-generation":
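The branch is plain feature detection: SDK v5 moved streaming into a dedicated `chat_stream()` method, while v4 streams through `chat(stream=True)`, so probing the client with `hasattr` avoids parsing version strings. A generic sketch of the pattern (the helper name is hypothetical, not part of the commit):

```python
from typing import Any, Dict, Iterator

def open_chat_stream(client: Any, request: Dict[str, Any]) -> Iterator[Any]:
    """Return a streaming chat iterator from either cohere SDK major."""
    if hasattr(client, "chat_stream"):  # SDK v5: dedicated streaming method
        return client.chat_stream(**request)
    # SDK v4: streaming is a flag on the regular chat() call
    return client.chat(**request, stream=True)
```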
@@ -160,7 +166,11 @@ class ChatCohere(BaseChatModel, BaseCohere):
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
-        stream = await self.async_client.chat(**request, stream=True)
+
+        if hasattr(self.async_client, "chat_stream"):  # detect and support sdk v5
+            stream = self.async_client.chat_stream(**request)
+        else:
+            stream = self.async_client.chat(**request, stream=True)
 
         async for data in stream:
             if data.event_type == "text-generation":
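The async path applies the same detection to `self.async_client`, so `astream` is also unchanged for callers on either SDK major. A hedged usage sketch (API key, model name, and prompt are placeholders):

```python
import asyncio

from langchain_community.chat_models import ChatCohere

async def main() -> None:
    chat = ChatCohere(cohere_api_key="YOUR_API_KEY", model="command")
    # Resolves internally to async_client.chat_stream() on v5,
    # or async_client.chat(stream=True) on v4.
    async for chunk in chat.astream("Write a haiku about the sea"):
        print(chunk.content, end="", flush=True)

asyncio.run(main())
```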
@@ -220,7 +230,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
             return await agenerate_from_stream(stream_iter)
 
         request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
-        response = self.client.chat(**request, stream=False)
+        response = self.client.chat(**request)
 
         message = AIMessage(content=response.text)
         generation_info = None
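Dropping the explicit `stream=False` keeps the non-streaming call valid on both majors: v4's `chat()` already defaults to a non-streaming response, and in v5 streaming lives in the separate `chat_stream()` method, so the flag is unnecessary either way. A minimal non-streaming usage sketch (same placeholder conventions as above):

```python
from langchain_community.chat_models import ChatCohere

chat = ChatCohere(cohere_api_key="YOUR_API_KEY", model="command")
# The non-streaming path now issues client.chat(**request) with no stream flag.
result = chat.invoke("Name three uses of text embeddings.")
print(result.content)
```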