community: Add support for cohere SDK v5 (keeps v4 backwards compatibility) (#19084)

- **Description:** Add support for cohere SDK v5 (keeps v4 backwards
compatibility)
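
The compatibility layer relies on runtime feature detection rather than version pins: the v5 SDK adds `Client.chat_stream(...)` and removes the `cohere.error` module, so the code branches on `hasattr`. A minimal sketch of that detection idiom (the `get_stream` helper is illustrative, not part of this change):

```python
import cohere

client = cohere.Client("<api-key>")

# v4 raises cohere.error.CohereError; v5 removed the `error` module,
# so retry logic falls back to plain Exception (see llms/cohere.py below).
retryable_error = (
    cohere.error.CohereError if hasattr(cohere, "error") else Exception
)

def get_stream(request: dict):
    """Illustrative helper: v5 streams via chat_stream, v4 via chat(stream=True)."""
    if hasattr(client, "chat_stream"):  # sdk v5
        return client.chat_stream(**request)
    return client.chat(**request, stream=True)  # sdk v4
```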

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
billytrend-cohere authored 2024-03-14 17:53:24 -05:00 (committed by GitHub)
parent 06165efb5b
commit 7253b816cc
6 changed files with 101 additions and 47 deletions

docs/docs/integrations/chat/cohere.ipynb

@@ -40,18 +40,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 1,
    "id": "2108b517-1e8d-473d-92fa-4f930e8072a7",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "········\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import getpass\n",
     "import os\n",
@@ -90,7 +82,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 3,
    "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
    "metadata": {
     "tags": []
@@ -103,7 +95,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 4,
    "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
    "metadata": {
     "tags": []
@@ -115,7 +107,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 5,
    "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
    "metadata": {
     "tags": []
@@ -124,22 +116,22 @@
    {
     "data": {
      "text/plain": [
-      "AIMessage(content=\"Who's there?\")"
+      "AIMessage(content=\"4! That's one, two, three, four. Keep adding and we'll reach new heights!\", response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 21, 'total_tokens': 94, 'billed_tokens': 25}})"
      ]
     },
-    "execution_count": 3,
+    "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
-   "messages = [HumanMessage(content=\"knock knock\")]\n",
+   "messages = [HumanMessage(content=\"1\"), HumanMessage(content=\"2 3\")]\n",
    "chat.invoke(messages)"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 4,
+  "execution_count": 6,
   "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
   "metadata": {
    "tags": []
@@ -148,10 +140,10 @@
    {
     "data": {
      "text/plain": [
-      "AIMessage(content=\"Who's there?\")"
+      "AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})"
     ]
    },
-   "execution_count": 4,
+   "execution_count": 6,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -162,7 +154,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 5,
+  "execution_count": 7,
   "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
   "metadata": {
    "tags": []
@@ -172,7 +164,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "Who's there?"
+    "4! It's a pleasure to be of service in this mathematical game."
    ]
   }
  ],
@@ -183,17 +175,17 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 6,
+  "execution_count": 8,
   "id": "064288e4-f184-4496-9427-bcf148fa055e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
-      "[AIMessage(content=\"Who's there?\")]"
+      "[AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})]"
     ]
    },
-   "execution_count": 6,
+   "execution_count": 8,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -214,7 +206,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 7,
+  "execution_count": 9,
   "id": "0851b103",
   "metadata": {},
   "outputs": [],
@@ -227,17 +219,17 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 8,
+  "execution_count": 10,
   "id": "ae950c0f-1691-47f1-b609-273033cae707",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
-      "AIMessage(content=\"Why did the bear go to the chiropractor?\\n\\nBecause she was feeling a bit grizzly!\\n\\nHope you found that joke about bears to be a little bit amusing! If you'd like to hear another one, just let me know. In the meantime, if you have any other questions or need assistance with a different topic, feel free to let me know. \\n\\nJust remember, even if you have a sore back like the bear, it's always best to consult a licensed professional for injuries or pain you may be experiencing. \\n\\nWould you like me to tell you another joke?\")"
+      "AIMessage(content='What do you call a bear with no teeth? A gummy bear!', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 72, 'response_tokens': 14, 'total_tokens': 86, 'billed_tokens': 20}})"
     ]
    },
-   "execution_count": 8,
+   "execution_count": 10,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -263,7 +255,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.9.1"
+  "version": "3.11.7"
  }
 },
 "nbformat": 4,

docs/docs/integrations/retrievers/cohere.ipynb

@@ -10,6 +10,19 @@
    "This notebook covers how to get started with Cohere RAG retriever. This allows you to leverage the ability to search documents over various connectors or by supplying your own."
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "2c367be3",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "import getpass\n",
+   "import os\n",
+   "\n",
+   "os.environ[\"COHERE_API_KEY\"] = getpass.getpass()"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": null,
@@ -218,7 +231,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.12"
+  "version": "3.11.7"
  }
 },
 "nbformat": 4,

File diff suppressed because one or more lines are too long

libs/community/langchain_community/chat_models/cohere.py

@@ -80,7 +80,7 @@ def get_cohere_chat_request(
         "AUTO" if documents is not None or connectors is not None else None
     )
-    return {
+    req = {
         "message": messages[-1].content,
         "chat_history": [
             {"role": get_role(x), "message": x.content} for x in messages[:-1]
@@ -91,6 +91,8 @@ def get_cohere_chat_request(
         **kwargs,
     }
+
+    return {k: v for k, v in req.items() if v is not None}
 
 
 class ChatCohere(BaseChatModel, BaseCohere):
     """`Cohere` chat large language models.
@@ -142,6 +144,10 @@ class ChatCohere(BaseChatModel, BaseCohere):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
-        stream = self.client.chat(**request, stream=True)
+
+        if hasattr(self.client, "chat_stream"):  # detect and support sdk v5
+            stream = self.client.chat_stream(**request)
+        else:
+            stream = self.client.chat(**request, stream=True)
 
         for data in stream:
@@ -160,7 +166,11 @@ class ChatCohere(BaseChatModel, BaseCohere):
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
-        stream = await self.async_client.chat(**request, stream=True)
+
+        if hasattr(self.async_client, "chat_stream"):  # detect and support sdk v5
+            stream = self.async_client.chat_stream(**request)
+        else:
+            stream = self.async_client.chat(**request, stream=True)
 
         async for data in stream:
             if data.event_type == "text-generation":
@@ -220,7 +230,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
             return await agenerate_from_stream(stream_iter)
 
         request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
-        response = self.client.chat(**request, stream=False)
+        response = self.client.chat(**request)
         message = AIMessage(content=response.text)
         generation_info = None
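
The switch from a direct `return {...}` to building `req` and filtering means keys with a `None` value are no longer sent to the SDK, which matters for v5's stricter, typed request validation. The filtering idiom in isolation (values here are hypothetical):

```python
# Hypothetical request dict; None-valued fields are stripped before the call.
req = {"message": "knock knock", "documents": None, "connectors": None}
payload = {k: v for k, v in req.items() if v is not None}
assert payload == {"message": "knock knock"}
```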

libs/community/langchain_community/embeddings/cohere.py

@@ -4,6 +4,8 @@
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
 from langchain_core.utils import get_from_dict_or_env
+
+from langchain_community.llms.cohere import _create_retry_decorator
 
 
 class CohereEmbeddings(BaseModel, Embeddings):
     """Cohere embedding models.
@@ -34,7 +36,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
     cohere_api_key: Optional[str] = None
 
-    max_retries: Optional[int] = 3
+    max_retries: int = 3
     """Maximum number of retries to make when generating."""
     request_timeout: Optional[float] = None
     """Timeout in seconds for the Cohere API request."""
@@ -52,7 +54,6 @@ class CohereEmbeddings(BaseModel, Embeddings):
         cohere_api_key = get_from_dict_or_env(
             values, "cohere_api_key", "COHERE_API_KEY"
         )
-        max_retries = values.get("max_retries")
         request_timeout = values.get("request_timeout")
 
         try:
@@ -61,13 +62,11 @@ class CohereEmbeddings(BaseModel, Embeddings):
             client_name = values["user_agent"]
             values["client"] = cohere.Client(
                 cohere_api_key,
-                max_retries=max_retries,
                 timeout=request_timeout,
                 client_name=client_name,
             )
             values["async_client"] = cohere.AsyncClient(
                 cohere_api_key,
-                max_retries=max_retries,
                 timeout=request_timeout,
                 client_name=client_name,
             )
@@ -78,10 +77,30 @@ class CohereEmbeddings(BaseModel, Embeddings):
             )
         return values
 
+    def embed_with_retry(self, **kwargs: Any) -> Any:
+        """Use tenacity to retry the embed call."""
+        retry_decorator = _create_retry_decorator(self.max_retries)
+
+        @retry_decorator
+        def _embed_with_retry(**kwargs: Any) -> Any:
+            return self.client.embed(**kwargs)
+
+        return _embed_with_retry(**kwargs)
+
+    def aembed_with_retry(self, **kwargs: Any) -> Any:
+        """Use tenacity to retry the embed call."""
+        retry_decorator = _create_retry_decorator(self.max_retries)
+
+        @retry_decorator
+        async def _embed_with_retry(**kwargs: Any) -> Any:
+            return await self.async_client.embed(**kwargs)
+
+        return _embed_with_retry(**kwargs)
+
     def embed(
         self, texts: List[str], *, input_type: Optional[str] = None
     ) -> List[List[float]]:
-        embeddings = self.client.embed(
+        embeddings = self.embed_with_retry(
             model=self.model,
             texts=texts,
             input_type=input_type,
@@ -93,7 +112,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
         self, texts: List[str], *, input_type: Optional[str] = None
     ) -> List[List[float]]:
         embeddings = (
-            await self.async_client.embed(
+            await self.aembed_with_retry(
                 model=self.model,
                 texts=texts,
                 input_type=input_type,
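
The new `embed_with_retry`/`aembed_with_retry` methods wrap each SDK call in the shared tenacity decorator instead of relying on the client-level `max_retries` constructor argument that v5 dropped. The closure pattern, reduced to a standalone sketch (`call_with_retry` is illustrative, not part of the change):

```python
from tenacity import retry, stop_after_attempt, wait_exponential

def call_with_retry(func, max_retries: int, **kwargs):
    # Build a tenacity-decorated closure around an arbitrary callable,
    # mirroring how embed_with_retry wraps self.client.embed.
    decorator = retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=1, min=4, max=10),
    )

    @decorator
    def _inner(**kw):
        return func(**kw)

    return _inner(**kwargs)

# e.g. call_with_retry(client.embed, 3, texts=["hello"], model="embed-english-v3.0")
```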

libs/community/langchain_community/llms/cohere.py

@@ -24,25 +24,32 @@
 from langchain_community.llms.utils import enforce_stop_tokens
 
 logger = logging.getLogger(__name__)
 
 
-def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]:
+def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
     import cohere
 
+    # support v4 and v5
+    retry_conditions = (
+        retry_if_exception_type(cohere.error.CohereError)
+        if hasattr(cohere, "error")
+        else retry_if_exception_type(Exception)
+    )
+
     min_seconds = 4
     max_seconds = 10
     # Wait 2^x * 1 second between each retry starting with
     # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
     return retry(
         reraise=True,
-        stop=stop_after_attempt(llm.max_retries),
+        stop=stop_after_attempt(max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(retry_if_exception_type(cohere.error.CohereError)),
+        retry=retry_conditions,
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
 
 
 def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm.max_retries)
 
     @retry_decorator
     def _completion_with_retry(**kwargs: Any) -> Any:
@@ -53,7 +60,7 @@ def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
 def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm.max_retries)
 
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
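
Since `_create_retry_decorator` now takes a plain int rather than the `Cohere` LLM instance, the embeddings module can share it. A quick standalone sanity check of the new signature (requires the cohere package to be importable; `flaky_call` is a stand-in for a real SDK call):

```python
from langchain_community.llms.cohere import _create_retry_decorator

decorator = _create_retry_decorator(max_retries=3)

@decorator
def flaky_call() -> str:
    # Stand-in for an SDK call; retried up to 3 times on a retryable error.
    return "ok"

print(flaky_call())  # -> "ok"
```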