Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-22 14:49:29 +00:00
community: Add support for cohere SDK v5 (keeps v4 backwards compatibility) (#19084)

- **Description:** Add support for cohere SDK v5 (keeps v4 backwards compatibility)

Co-authored-by: Erick Friis <erick@langchain.dev>

This commit is contained in:
parent 06165efb5b
commit 7253b816cc
@@ -40,18 +40,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 1,
    "id": "2108b517-1e8d-473d-92fa-4f930e8072a7",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "········\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import getpass\n",
     "import os\n",
@@ -90,7 +82,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 3,
    "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
    "metadata": {
     "tags": []
@@ -103,7 +95,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 4,
    "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
    "metadata": {
     "tags": []
@@ -115,7 +107,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 5,
    "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
    "metadata": {
     "tags": []
@@ -124,22 +116,22 @@
     {
      "data": {
       "text/plain": [
-       "AIMessage(content=\"Who's there?\")"
+       "AIMessage(content=\"4! That's one, two, three, four. Keep adding and we'll reach new heights!\", response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 21, 'total_tokens': 94, 'billed_tokens': 25}})"
       ]
      },
-     "execution_count": 3,
+     "execution_count": 5,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "messages = [HumanMessage(content=\"knock knock\")]\n",
+    "messages = [HumanMessage(content=\"1\"), HumanMessage(content=\"2 3\")]\n",
     "chat.invoke(messages)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 6,
    "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
    "metadata": {
     "tags": []
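With SDK v5, `ChatCohere` responses carry Cohere's accounting fields through `response_metadata`, as the updated cell output above shows. A minimal sketch of reading the token counts (the model name is illustrative, and a `COHERE_API_KEY` is assumed to be set; field names are taken from the output above):

```python
from langchain_community.chat_models import ChatCohere
from langchain_core.messages import HumanMessage

chat = ChatCohere(model="command")  # model name is illustrative

# invoke() returns an AIMessage; v5 responses attach token accounting
ai_msg = chat.invoke([HumanMessage(content="1"), HumanMessage(content="2 3")])
token_count = ai_msg.response_metadata.get("token_count", {})
print(token_count.get("prompt_tokens"), token_count.get("billed_tokens"))
```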
@@ -148,10 +140,10 @@
     {
      "data": {
       "text/plain": [
-       "AIMessage(content=\"Who's there?\")"
+       "AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})"
       ]
      },
-     "execution_count": 4,
+     "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -162,7 +154,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 7,
    "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
    "metadata": {
     "tags": []
@@ -172,7 +164,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Who's there?"
+      "4! It's a pleasure to be of service in this mathematical game."
      ]
     }
    ],
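The stdout above comes from a streamed response printed token by token. A minimal sketch of the kind of streaming call that produces such output (assumes the same `chat` instance and message list as the cells above; this is a sketch, not the notebook's exact cell):

```python
# stream() yields AIMessageChunk objects; printing without a newline
# reproduces the incremental stdout seen in the notebook cell
for chunk in chat.stream([HumanMessage(content="1"), HumanMessage(content="2 3")]):
    print(chunk.content, end="", flush=True)
```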
@@ -183,17 +175,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 8,
    "id": "064288e4-f184-4496-9427-bcf148fa055e",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "[AIMessage(content=\"Who's there?\")]"
+       "[AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})]"
       ]
      },
-     "execution_count": 6,
+     "execution_count": 8,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -214,7 +206,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 9,
    "id": "0851b103",
    "metadata": {},
    "outputs": [],
@@ -227,17 +219,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 10,
    "id": "ae950c0f-1691-47f1-b609-273033cae707",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "AIMessage(content=\"Why did the bear go to the chiropractor?\\n\\nBecause she was feeling a bit grizzly!\\n\\nHope you found that joke about bears to be a little bit amusing! If you'd like to hear another one, just let me know. In the meantime, if you have any other questions or need assistance with a different topic, feel free to let me know. \\n\\nJust remember, even if you have a sore back like the bear, it's always best to consult a licensed professional for injuries or pain you may be experiencing. \\n\\nWould you like me to tell you another joke?\")"
+       "AIMessage(content='What do you call a bear with no teeth? A gummy bear!', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 72, 'response_tokens': 14, 'total_tokens': 86, 'billed_tokens': 20}})"
       ]
      },
-     "execution_count": 8,
+     "execution_count": 10,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -263,7 +255,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.1"
+   "version": "3.11.7"
  }
 },
 "nbformat": 4,
@@ -10,6 +10,19 @@
    "This notebook covers how to get started with Cohere RAG retriever. This allows you to leverage the ability to search documents over various connectors or by supplying your own."
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "2c367be3",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "import getpass\n",
+   "import os\n",
+   "\n",
+   "os.environ[\"COHERE_API_KEY\"] = getpass.getpass()"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": null,
@@ -218,7 +231,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.12"
+   "version": "3.11.7"
  }
 },
 "nbformat": 4,
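The added cell gives the retriever notebook the same API-key bootstrap the chat notebook uses. For context, a minimal sketch of how the Cohere RAG retriever is typically driven afterward (class and argument names assumed from the langchain-community integration of this era; they are not shown in the hunk):

```python
from langchain_community.chat_models import ChatCohere
from langchain_community.retrievers import CohereRagRetriever

# The retriever delegates to the chat model's connector-backed search
rag = CohereRagRetriever(llm=ChatCohere())
docs = rag.get_relevant_documents("What is cohere ai?")
for doc in docs:
    print(doc.metadata.get("title"), doc.page_content[:80])
```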
File diff suppressed because one or more lines are too long
@@ -80,7 +80,7 @@ def get_cohere_chat_request(
         "AUTO" if documents is not None or connectors is not None else None
     )
 
-    return {
+    req = {
         "message": messages[-1].content,
         "chat_history": [
             {"role": get_role(x), "message": x.content} for x in messages[:-1]
@@ -91,6 +91,8 @@ def get_cohere_chat_request(
         **kwargs,
     }
 
+    return {k: v for k, v in req.items() if v is not None}
+
 
 class ChatCohere(BaseChatModel, BaseCohere):
     """`Cohere` chat large language models.
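Filtering out `None` values before the request reaches the SDK keeps the payload compatible with both client versions; v5's typed request models are stricter about unexpected nulls than v4 was (the motivation is inferred from the shape of the change, not stated in the commit). A standalone sketch of the idiom:

```python
# Hypothetical request dict: optional fields default to None
req = {"message": "2 3", "temperature": None, "connectors": None}

# Dropping None-valued keys sends only the fields the caller actually set,
# so the same dict works against both the v4 and v5 cohere clients
clean = {k: v for k, v in req.items() if v is not None}
assert clean == {"message": "2 3"}
```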
@@ -142,7 +144,11 @@ class ChatCohere(BaseChatModel, BaseCohere):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
-        stream = self.client.chat(**request, stream=True)
+
+        if hasattr(self.client, "chat_stream"):  # detect and support sdk v5
+            stream = self.client.chat_stream(**request)
+        else:
+            stream = self.client.chat(**request, stream=True)
 
         for data in stream:
             if data.event_type == "text-generation":
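The version check relies on duck typing rather than inspecting `cohere.__version__`: v5 renamed the streaming entry point to `chat_stream`, so probing for the attribute selects the right call on any client. A self-contained sketch of the pattern (the stub classes are illustrative stand-ins, not the real SDK):

```python
class V4Client:
    """Stands in for cohere.Client from SDK v4."""
    def chat(self, message, stream=False):
        return iter([f"v4:{message}"]) if stream else f"v4:{message}"

class V5Client:
    """Stands in for SDK v5's client, where streaming moved to chat_stream()."""
    def chat(self, message):
        return f"v5:{message}"
    def chat_stream(self, message):
        yield f"v5:{message}"

def stream_chat(client, message):
    # Same detection the patched _stream() uses: prefer v5's chat_stream,
    # fall back to v4's chat(..., stream=True)
    if hasattr(client, "chat_stream"):
        return client.chat_stream(message=message)
    return client.chat(message=message, stream=True)

for c in (V4Client(), V5Client()):
    print(list(stream_chat(c, "hi")))
```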
@@ -160,7 +166,11 @@ class ChatCohere(BaseChatModel, BaseCohere):
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
-        stream = await self.async_client.chat(**request, stream=True)
+
+        if hasattr(self.async_client, "chat_stream"):  # detect and support sdk v5
+            stream = self.async_client.chat_stream(**request)
+        else:
+            stream = self.async_client.chat(**request, stream=True)
 
         async for data in stream:
             if data.event_type == "text-generation":
@@ -220,7 +230,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
             return await agenerate_from_stream(stream_iter)
 
         request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
-        response = self.client.chat(**request, stream=False)
+        response = self.client.chat(**request)
 
         message = AIMessage(content=response.text)
         generation_info = None
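Dropping the explicit `stream=False` is what makes the non-streaming path portable: v4's `chat()` already defaults to a non-streaming response, while v5's `chat()` no longer accepts a `stream` keyword at all. A sketch against the same illustrative stubs as above:

```python
def blocking_chat(client, message):
    # Omitting stream= entirely works on both stubs: V4Client defaults
    # to stream=False, and V5Client.chat() has no stream parameter
    return client.chat(message=message)

print(blocking_chat(V4Client(), "hi"))  # v4:hi
print(blocking_chat(V5Client(), "hi"))  # v5:hi
```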
@@ -4,6 +4,8 @@ from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
 from langchain_core.utils import get_from_dict_or_env
 
+from langchain_community.llms.cohere import _create_retry_decorator
+
 
 class CohereEmbeddings(BaseModel, Embeddings):
     """Cohere embedding models.
@@ -34,7 +36,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
 
     cohere_api_key: Optional[str] = None
 
-    max_retries: Optional[int] = 3
+    max_retries: int = 3
     """Maximum number of retries to make when generating."""
     request_timeout: Optional[float] = None
     """Timeout in seconds for the Cohere API request."""
@@ -52,7 +54,6 @@ class CohereEmbeddings(BaseModel, Embeddings):
         cohere_api_key = get_from_dict_or_env(
             values, "cohere_api_key", "COHERE_API_KEY"
         )
-        max_retries = values.get("max_retries")
         request_timeout = values.get("request_timeout")
 
         try:
@@ -61,13 +62,11 @@ class CohereEmbeddings(BaseModel, Embeddings):
             client_name = values["user_agent"]
             values["client"] = cohere.Client(
                 cohere_api_key,
-                max_retries=max_retries,
                 timeout=request_timeout,
                 client_name=client_name,
             )
             values["async_client"] = cohere.AsyncClient(
                 cohere_api_key,
-                max_retries=max_retries,
                 timeout=request_timeout,
                 client_name=client_name,
             )
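`max_retries` leaves the client constructors because v5's `cohere.Client` no longer accepts that keyword; retrying moves to the tenacity wrappers added below (the rationale is inferred from the shape of the change). A sketch of constructing a client with only the keyword set the two SDK versions share (helper name is hypothetical):

```python
import cohere

def make_client(api_key, timeout=None, client_name="langchain"):
    # Keyword set is the v4/v5 intersection: no max_retries, since v5's
    # constructor dropped it; retries are handled by tenacity instead
    return cohere.Client(api_key, timeout=timeout, client_name=client_name)
```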
@@ -78,10 +77,30 @@ class CohereEmbeddings(BaseModel, Embeddings):
             )
         return values
 
+    def embed_with_retry(self, **kwargs: Any) -> Any:
+        """Use tenacity to retry the embed call."""
+        retry_decorator = _create_retry_decorator(self.max_retries)
+
+        @retry_decorator
+        def _embed_with_retry(**kwargs: Any) -> Any:
+            return self.client.embed(**kwargs)
+
+        return _embed_with_retry(**kwargs)
+
+    def aembed_with_retry(self, **kwargs: Any) -> Any:
+        """Use tenacity to retry the embed call."""
+        retry_decorator = _create_retry_decorator(self.max_retries)
+
+        @retry_decorator
+        async def _embed_with_retry(**kwargs: Any) -> Any:
+            return await self.async_client.embed(**kwargs)
+
+        return _embed_with_retry(**kwargs)
+
     def embed(
         self, texts: List[str], *, input_type: Optional[str] = None
     ) -> List[List[float]]:
-        embeddings = self.client.embed(
+        embeddings = self.embed_with_retry(
             model=self.model,
             texts=texts,
             input_type=input_type,
@@ -93,7 +112,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
         self, texts: List[str], *, input_type: Optional[str] = None
     ) -> List[List[float]]:
         embeddings = (
-            await self.async_client.embed(
+            await self.aembed_with_retry(
                 model=self.model,
                 texts=texts,
                 input_type=input_type,
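With the wrappers in place, every embed call inherits the retry policy sized by `max_retries`. A hedged usage sketch (model name illustrative; assumes `COHERE_API_KEY` is set in the environment):

```python
from langchain_community.embeddings import CohereEmbeddings

# max_retries now feeds the tenacity decorator instead of the client ctor
embeddings = CohereEmbeddings(model="embed-english-v3.0", max_retries=5)
vectors = embeddings.embed_documents(["hello world", "goodbye world"])
print(len(vectors), len(vectors[0]))
```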
@@ -24,25 +24,32 @@ from langchain_community.llms.utils import enforce_stop_tokens
 logger = logging.getLogger(__name__)
 
 
-def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]:
+def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
     import cohere
 
+    # support v4 and v5
+    retry_conditions = (
+        retry_if_exception_type(cohere.error.CohereError)
+        if hasattr(cohere, "error")
+        else retry_if_exception_type(Exception)
+    )
+
     min_seconds = 4
     max_seconds = 10
     # Wait 2^x * 1 second between each retry starting with
     # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
     return retry(
         reraise=True,
-        stop=stop_after_attempt(llm.max_retries),
+        stop=stop_after_attempt(max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(retry_if_exception_type(cohere.error.CohereError)),
+        retry=retry_conditions,
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
 
 
 def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm.max_retries)
 
     @retry_decorator
     def _completion_with_retry(**kwargs: Any) -> Any:
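Taking a plain `max_retries` instead of the whole `Cohere` LLM is what lets `CohereEmbeddings` reuse this decorator, and the `hasattr(cohere, "error")` probe handles v5 dropping the `cohere.error` module. A self-contained sketch of the same tenacity pattern (the custom exception stands in for `cohere.error.CohereError`):

```python
import logging
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)

class FakeCohereError(Exception):
    """Stand-in for cohere.error.CohereError (absent in SDK v5)."""

def make_retry_decorator(max_retries: int):
    # Exponential backoff between 4s and 10s, giving up after max_retries
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(FakeCohereError),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )

@make_retry_decorator(max_retries=3)
def flaky_call(attempts=[0]):
    # Fails twice, then succeeds on the third attempt
    attempts[0] += 1
    if attempts[0] < 3:
        raise FakeCohereError("transient failure")
    return "ok"
```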
@@ -53,7 +60,7 @@ def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
 
 def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm.max_retries)
 
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any: