mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
community: Add support for cohere SDK v5 (keeps v4 backwards compatibility) (#19084)
- **Description:** Add support for cohere SDK v5 (keeps v4 backwards compatibility) --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
06165efb5b
commit
7253b816cc
@ -40,18 +40,10 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 1,
|
||||||
"id": "2108b517-1e8d-473d-92fa-4f930e8072a7",
|
"id": "2108b517-1e8d-473d-92fa-4f930e8072a7",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"········\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"import getpass\n",
|
"import getpass\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
@ -90,7 +82,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 3,
|
||||||
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
|
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -103,7 +95,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 4,
|
||||||
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
|
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -115,7 +107,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 5,
|
||||||
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
|
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -124,22 +116,22 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"AIMessage(content=\"Who's there?\")"
|
"AIMessage(content=\"4! That's one, two, three, four. Keep adding and we'll reach new heights!\", response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 21, 'total_tokens': 94, 'billed_tokens': 25}})"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"messages = [HumanMessage(content=\"knock knock\")]\n",
|
"messages = [HumanMessage(content=\"1\"), HumanMessage(content=\"2 3\")]\n",
|
||||||
"chat.invoke(messages)"
|
"chat.invoke(messages)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 6,
|
||||||
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
|
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -148,10 +140,10 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"AIMessage(content=\"Who's there?\")"
|
"AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 4,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -162,7 +154,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 7,
|
||||||
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
|
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -172,7 +164,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Who's there?"
|
"4! It's a pleasure to be of service in this mathematical game."
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -183,17 +175,17 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 8,
|
||||||
"id": "064288e4-f184-4496-9427-bcf148fa055e",
|
"id": "064288e4-f184-4496-9427-bcf148fa055e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"[AIMessage(content=\"Who's there?\")]"
|
"[AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 6,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -214,7 +206,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 9,
|
||||||
"id": "0851b103",
|
"id": "0851b103",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -227,17 +219,17 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 10,
|
||||||
"id": "ae950c0f-1691-47f1-b609-273033cae707",
|
"id": "ae950c0f-1691-47f1-b609-273033cae707",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"AIMessage(content=\"Why did the bear go to the chiropractor?\\n\\nBecause she was feeling a bit grizzly!\\n\\nHope you found that joke about bears to be a little bit amusing! If you'd like to hear another one, just let me know. In the meantime, if you have any other questions or need assistance with a different topic, feel free to let me know. \\n\\nJust remember, even if you have a sore back like the bear, it's always best to consult a licensed professional for injuries or pain you may be experiencing. \\n\\nWould you like me to tell you another joke?\")"
|
"AIMessage(content='What do you call a bear with no teeth? A gummy bear!', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 72, 'response_tokens': 14, 'total_tokens': 86, 'billed_tokens': 20}})"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 8,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -263,7 +255,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.9.1"
|
"version": "3.11.7"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -10,6 +10,19 @@
|
|||||||
"This notebook covers how to get started with Cohere RAG retriever. This allows you to leverage the ability to search documents over various connectors or by supplying your own."
|
"This notebook covers how to get started with Cohere RAG retriever. This allows you to leverage the ability to search documents over various connectors or by supplying your own."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "2c367be3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import getpass\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ[\"COHERE_API_KEY\"] = getpass.getpass()"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
@ -218,7 +231,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.12"
|
"version": "3.11.7"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
File diff suppressed because one or more lines are too long
@ -80,7 +80,7 @@ def get_cohere_chat_request(
|
|||||||
"AUTO" if documents is not None or connectors is not None else None
|
"AUTO" if documents is not None or connectors is not None else None
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
req = {
|
||||||
"message": messages[-1].content,
|
"message": messages[-1].content,
|
||||||
"chat_history": [
|
"chat_history": [
|
||||||
{"role": get_role(x), "message": x.content} for x in messages[:-1]
|
{"role": get_role(x), "message": x.content} for x in messages[:-1]
|
||||||
@ -91,6 +91,8 @@ def get_cohere_chat_request(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return {k: v for k, v in req.items() if v is not None}
|
||||||
|
|
||||||
|
|
||||||
class ChatCohere(BaseChatModel, BaseCohere):
|
class ChatCohere(BaseChatModel, BaseCohere):
|
||||||
"""`Cohere` chat large language models.
|
"""`Cohere` chat large language models.
|
||||||
@ -142,6 +144,10 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[ChatGenerationChunk]:
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||||
|
|
||||||
|
if hasattr(self.client, "chat_stream"): # detect and support sdk v5
|
||||||
|
stream = self.client.chat_stream(**request)
|
||||||
|
else:
|
||||||
stream = self.client.chat(**request, stream=True)
|
stream = self.client.chat(**request, stream=True)
|
||||||
|
|
||||||
for data in stream:
|
for data in stream:
|
||||||
@ -160,7 +166,11 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[ChatGenerationChunk]:
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||||
stream = await self.async_client.chat(**request, stream=True)
|
|
||||||
|
if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5
|
||||||
|
stream = self.async_client.chat_stream(**request)
|
||||||
|
else:
|
||||||
|
stream = self.async_client.chat(**request, stream=True)
|
||||||
|
|
||||||
async for data in stream:
|
async for data in stream:
|
||||||
if data.event_type == "text-generation":
|
if data.event_type == "text-generation":
|
||||||
@ -220,7 +230,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
|||||||
return await agenerate_from_stream(stream_iter)
|
return await agenerate_from_stream(stream_iter)
|
||||||
|
|
||||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||||
response = self.client.chat(**request, stream=False)
|
response = self.client.chat(**request)
|
||||||
|
|
||||||
message = AIMessage(content=response.text)
|
message = AIMessage(content=response.text)
|
||||||
generation_info = None
|
generation_info = None
|
||||||
|
@ -4,6 +4,8 @@ from langchain_core.embeddings import Embeddings
|
|||||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
from langchain_community.llms.cohere import _create_retry_decorator
|
||||||
|
|
||||||
|
|
||||||
class CohereEmbeddings(BaseModel, Embeddings):
|
class CohereEmbeddings(BaseModel, Embeddings):
|
||||||
"""Cohere embedding models.
|
"""Cohere embedding models.
|
||||||
@ -34,7 +36,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
cohere_api_key: Optional[str] = None
|
cohere_api_key: Optional[str] = None
|
||||||
|
|
||||||
max_retries: Optional[int] = 3
|
max_retries: int = 3
|
||||||
"""Maximum number of retries to make when generating."""
|
"""Maximum number of retries to make when generating."""
|
||||||
request_timeout: Optional[float] = None
|
request_timeout: Optional[float] = None
|
||||||
"""Timeout in seconds for the Cohere API request."""
|
"""Timeout in seconds for the Cohere API request."""
|
||||||
@ -52,7 +54,6 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
cohere_api_key = get_from_dict_or_env(
|
cohere_api_key = get_from_dict_or_env(
|
||||||
values, "cohere_api_key", "COHERE_API_KEY"
|
values, "cohere_api_key", "COHERE_API_KEY"
|
||||||
)
|
)
|
||||||
max_retries = values.get("max_retries")
|
|
||||||
request_timeout = values.get("request_timeout")
|
request_timeout = values.get("request_timeout")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -61,13 +62,11 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
client_name = values["user_agent"]
|
client_name = values["user_agent"]
|
||||||
values["client"] = cohere.Client(
|
values["client"] = cohere.Client(
|
||||||
cohere_api_key,
|
cohere_api_key,
|
||||||
max_retries=max_retries,
|
|
||||||
timeout=request_timeout,
|
timeout=request_timeout,
|
||||||
client_name=client_name,
|
client_name=client_name,
|
||||||
)
|
)
|
||||||
values["async_client"] = cohere.AsyncClient(
|
values["async_client"] = cohere.AsyncClient(
|
||||||
cohere_api_key,
|
cohere_api_key,
|
||||||
max_retries=max_retries,
|
|
||||||
timeout=request_timeout,
|
timeout=request_timeout,
|
||||||
client_name=client_name,
|
client_name=client_name,
|
||||||
)
|
)
|
||||||
@ -78,10 +77,30 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
def embed_with_retry(self, **kwargs: Any) -> Any:
|
||||||
|
"""Use tenacity to retry the embed call."""
|
||||||
|
retry_decorator = _create_retry_decorator(self.max_retries)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
def _embed_with_retry(**kwargs: Any) -> Any:
|
||||||
|
return self.client.embed(**kwargs)
|
||||||
|
|
||||||
|
return _embed_with_retry(**kwargs)
|
||||||
|
|
||||||
|
def aembed_with_retry(self, **kwargs: Any) -> Any:
|
||||||
|
"""Use tenacity to retry the embed call."""
|
||||||
|
retry_decorator = _create_retry_decorator(self.max_retries)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
async def _embed_with_retry(**kwargs: Any) -> Any:
|
||||||
|
return await self.async_client.embed(**kwargs)
|
||||||
|
|
||||||
|
return _embed_with_retry(**kwargs)
|
||||||
|
|
||||||
def embed(
|
def embed(
|
||||||
self, texts: List[str], *, input_type: Optional[str] = None
|
self, texts: List[str], *, input_type: Optional[str] = None
|
||||||
) -> List[List[float]]:
|
) -> List[List[float]]:
|
||||||
embeddings = self.client.embed(
|
embeddings = self.embed_with_retry(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
texts=texts,
|
texts=texts,
|
||||||
input_type=input_type,
|
input_type=input_type,
|
||||||
@ -93,7 +112,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
|||||||
self, texts: List[str], *, input_type: Optional[str] = None
|
self, texts: List[str], *, input_type: Optional[str] = None
|
||||||
) -> List[List[float]]:
|
) -> List[List[float]]:
|
||||||
embeddings = (
|
embeddings = (
|
||||||
await self.async_client.embed(
|
await self.aembed_with_retry(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
texts=texts,
|
texts=texts,
|
||||||
input_type=input_type,
|
input_type=input_type,
|
||||||
|
@ -24,25 +24,32 @@ from langchain_community.llms.utils import enforce_stop_tokens
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]:
|
def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
|
||||||
import cohere
|
import cohere
|
||||||
|
|
||||||
|
# support v4 and v5
|
||||||
|
retry_conditions = (
|
||||||
|
retry_if_exception_type(cohere.error.CohereError)
|
||||||
|
if hasattr(cohere, "error")
|
||||||
|
else retry_if_exception_type(Exception)
|
||||||
|
)
|
||||||
|
|
||||||
min_seconds = 4
|
min_seconds = 4
|
||||||
max_seconds = 10
|
max_seconds = 10
|
||||||
# Wait 2^x * 1 second between each retry starting with
|
# Wait 2^x * 1 second between each retry starting with
|
||||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||||
return retry(
|
return retry(
|
||||||
reraise=True,
|
reraise=True,
|
||||||
stop=stop_after_attempt(llm.max_retries),
|
stop=stop_after_attempt(max_retries),
|
||||||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||||
retry=(retry_if_exception_type(cohere.error.CohereError)),
|
retry=retry_conditions,
|
||||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
|
def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
|
||||||
"""Use tenacity to retry the completion call."""
|
"""Use tenacity to retry the completion call."""
|
||||||
retry_decorator = _create_retry_decorator(llm)
|
retry_decorator = _create_retry_decorator(llm.max_retries)
|
||||||
|
|
||||||
@retry_decorator
|
@retry_decorator
|
||||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||||
@ -53,7 +60,7 @@ def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
|
|||||||
|
|
||||||
def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
|
def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
|
||||||
"""Use tenacity to retry the completion call."""
|
"""Use tenacity to retry the completion call."""
|
||||||
retry_decorator = _create_retry_decorator(llm)
|
retry_decorator = _create_retry_decorator(llm.max_retries)
|
||||||
|
|
||||||
@retry_decorator
|
@retry_decorator
|
||||||
async def _completion_with_retry(**kwargs: Any) -> Any:
|
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||||
|
Loading…
Reference in New Issue
Block a user