community: Add support for cohere SDK v5 (keeps v4 backwards compatibility) (#19084)

- **Description:** Add support for cohere SDK v5 (keeps v4 backwards
compatibility)

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
billytrend-cohere
2024-03-14 17:53:24 -05:00
committed by GitHub
parent 06165efb5b
commit 7253b816cc
6 changed files with 101 additions and 47 deletions

View File

@@ -80,7 +80,7 @@ def get_cohere_chat_request(
"AUTO" if documents is not None or connectors is not None else None
)
return {
req = {
"message": messages[-1].content,
"chat_history": [
{"role": get_role(x), "message": x.content} for x in messages[:-1]
@@ -91,6 +91,8 @@ def get_cohere_chat_request(
**kwargs,
}
return {k: v for k, v in req.items() if v is not None}
class ChatCohere(BaseChatModel, BaseCohere):
"""`Cohere` chat large language models.
@@ -142,7 +144,11 @@ class ChatCohere(BaseChatModel, BaseCohere):
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
stream = self.client.chat(**request, stream=True)
if hasattr(self.client, "chat_stream"): # detect and support sdk v5
stream = self.client.chat_stream(**request)
else:
stream = self.client.chat(**request, stream=True)
for data in stream:
if data.event_type == "text-generation":
@@ -160,7 +166,11 @@ class ChatCohere(BaseChatModel, BaseCohere):
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
stream = await self.async_client.chat(**request, stream=True)
if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5
stream = self.async_client.chat_stream(**request)
else:
stream = self.async_client.chat(**request, stream=True)
async for data in stream:
if data.event_type == "text-generation":
@@ -220,7 +230,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
return await agenerate_from_stream(stream_iter)
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request, stream=False)
response = self.client.chat(**request)
message = AIMessage(content=response.text)
generation_info = None