community: Add support for cohere SDK v5 (keeps v4 backwards compatibility) (#19084)

- **Description:** Add support for cohere SDK v5 (keeps v4 backwards
compatibility)

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
billytrend-cohere 2024-03-14 17:53:24 -05:00 committed by GitHub
parent 06165efb5b
commit 7253b816cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 101 additions and 47 deletions

View File

@ -40,18 +40,10 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 1,
"id": "2108b517-1e8d-473d-92fa-4f930e8072a7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"········\n"
]
}
],
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
@ -90,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
"metadata": {
"tags": []
@ -103,7 +95,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
"metadata": {
"tags": []
@ -115,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
"metadata": {
"tags": []
@ -124,22 +116,22 @@
{
"data": {
"text/plain": [
"AIMessage(content=\"Who's there?\")"
"AIMessage(content=\"4! That's one, two, three, four. Keep adding and we'll reach new heights!\", response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 21, 'total_tokens': 94, 'billed_tokens': 25}})"
]
},
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [HumanMessage(content=\"knock knock\")]\n",
"messages = [HumanMessage(content=\"1\"), HumanMessage(content=\"2 3\")]\n",
"chat.invoke(messages)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
"metadata": {
"tags": []
@ -148,10 +140,10 @@
{
"data": {
"text/plain": [
"AIMessage(content=\"Who's there?\")"
"AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})"
]
},
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -162,7 +154,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
"metadata": {
"tags": []
@ -172,7 +164,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Who's there?"
"4! It's a pleasure to be of service in this mathematical game."
]
}
],
@ -183,17 +175,17 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"id": "064288e4-f184-4496-9427-bcf148fa055e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[AIMessage(content=\"Who's there?\")]"
"[AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})]"
]
},
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@ -214,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 9,
"id": "0851b103",
"metadata": {},
"outputs": [],
@ -227,17 +219,17 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 10,
"id": "ae950c0f-1691-47f1-b609-273033cae707",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"Why did the bear go to the chiropractor?\\n\\nBecause she was feeling a bit grizzly!\\n\\nHope you found that joke about bears to be a little bit amusing! If you'd like to hear another one, just let me know. In the meantime, if you have any other questions or need assistance with a different topic, feel free to let me know. \\n\\nJust remember, even if you have a sore back like the bear, it's always best to consult a licensed professional for injuries or pain you may be experiencing. \\n\\nWould you like me to tell you another joke?\")"
"AIMessage(content='What do you call a bear with no teeth? A gummy bear!', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 72, 'response_tokens': 14, 'total_tokens': 86, 'billed_tokens': 20}})"
]
},
"execution_count": 8,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@ -263,7 +255,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.11.7"
}
},
"nbformat": 4,

View File

@ -10,6 +10,19 @@
"This notebook covers how to get started with Cohere RAG retriever. This allows you to leverage the ability to search documents over various connectors or by supplying your own."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c367be3",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"COHERE_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -218,7 +231,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.7"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

View File

@ -80,7 +80,7 @@ def get_cohere_chat_request(
"AUTO" if documents is not None or connectors is not None else None
)
return {
req = {
"message": messages[-1].content,
"chat_history": [
{"role": get_role(x), "message": x.content} for x in messages[:-1]
@ -91,6 +91,8 @@ def get_cohere_chat_request(
**kwargs,
}
return {k: v for k, v in req.items() if v is not None}
class ChatCohere(BaseChatModel, BaseCohere):
"""`Cohere` chat large language models.
@ -142,7 +144,11 @@ class ChatCohere(BaseChatModel, BaseCohere):
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
stream = self.client.chat(**request, stream=True)
if hasattr(self.client, "chat_stream"): # detect and support sdk v5
stream = self.client.chat_stream(**request)
else:
stream = self.client.chat(**request, stream=True)
for data in stream:
if data.event_type == "text-generation":
@ -160,7 +166,11 @@ class ChatCohere(BaseChatModel, BaseCohere):
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
stream = await self.async_client.chat(**request, stream=True)
if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5
stream = self.async_client.chat_stream(**request)
else:
stream = self.async_client.chat(**request, stream=True)
async for data in stream:
if data.event_type == "text-generation":
@ -220,7 +230,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
return await agenerate_from_stream(stream_iter)
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request, stream=False)
response = self.client.chat(**request)
message = AIMessage(content=response.text)
generation_info = None

View File

@ -4,6 +4,8 @@ from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_community.llms.cohere import _create_retry_decorator
class CohereEmbeddings(BaseModel, Embeddings):
"""Cohere embedding models.
@ -34,7 +36,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
cohere_api_key: Optional[str] = None
max_retries: Optional[int] = 3
max_retries: int = 3
"""Maximum number of retries to make when generating."""
request_timeout: Optional[float] = None
"""Timeout in seconds for the Cohere API request."""
@ -52,7 +54,6 @@ class CohereEmbeddings(BaseModel, Embeddings):
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
max_retries = values.get("max_retries")
request_timeout = values.get("request_timeout")
try:
@ -61,13 +62,11 @@ class CohereEmbeddings(BaseModel, Embeddings):
client_name = values["user_agent"]
values["client"] = cohere.Client(
cohere_api_key,
max_retries=max_retries,
timeout=request_timeout,
client_name=client_name,
)
values["async_client"] = cohere.AsyncClient(
cohere_api_key,
max_retries=max_retries,
timeout=request_timeout,
client_name=client_name,
)
@ -78,10 +77,30 @@ class CohereEmbeddings(BaseModel, Embeddings):
)
return values
def embed_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the embed call."""
retry_decorator = _create_retry_decorator(self.max_retries)
@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
return self.client.embed(**kwargs)
return _embed_with_retry(**kwargs)
def aembed_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the embed call."""
retry_decorator = _create_retry_decorator(self.max_retries)
@retry_decorator
async def _embed_with_retry(**kwargs: Any) -> Any:
return await self.async_client.embed(**kwargs)
return _embed_with_retry(**kwargs)
def embed(
self, texts: List[str], *, input_type: Optional[str] = None
) -> List[List[float]]:
embeddings = self.client.embed(
embeddings = self.embed_with_retry(
model=self.model,
texts=texts,
input_type=input_type,
@ -93,7 +112,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
self, texts: List[str], *, input_type: Optional[str] = None
) -> List[List[float]]:
embeddings = (
await self.async_client.embed(
await self.aembed_with_retry(
model=self.model,
texts=texts,
input_type=input_type,

View File

@ -24,25 +24,32 @@ from langchain_community.llms.utils import enforce_stop_tokens
logger = logging.getLogger(__name__)
def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]:
def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
import cohere
# support v4 and v5
retry_conditions = (
retry_if_exception_type(cohere.error.CohereError)
if hasattr(cohere, "error")
else retry_if_exception_type(Exception)
)
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(retry_if_exception_type(cohere.error.CohereError)),
retry=retry_conditions,
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
retry_decorator = _create_retry_decorator(llm.max_retries)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
@ -53,7 +60,7 @@ def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
retry_decorator = _create_retry_decorator(llm.max_retries)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any: