Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-06 03:27:55 +00:00.
Implement async support for Cohere (#8237)
This PR introduces async API support for Cohere, both LLM and embeddings. It requires updating `cohere` package to `^4`. Tagging @hwchase17, @baskaryan, @agola11 --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
bf1357f584
commit
c5988c1d4b
@ -9,7 +9,7 @@
|
||||
"\n",
|
||||
"LangChain provides async support for LLMs by leveraging the [asyncio](https://docs.python.org/3/library/asyncio.html) library.\n",
|
||||
"\n",
|
||||
"Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, `OpenAI`, `PromptLayerOpenAI`, `ChatOpenAI` and `Anthropic` are supported, but async support for other LLMs is on the roadmap.\n",
|
||||
"Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, `OpenAI`, `PromptLayerOpenAI`, `ChatOpenAI`, `Anthropic` and `Cohere` are supported, but async support for other LLMs is on the roadmap.\n",
|
||||
"\n",
|
||||
"You can use the `agenerate` method to call an OpenAI LLM asynchronously."
|
||||
]
|
||||
@ -56,7 +56,7 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"I'm doing well, thank you. How about you?\n",
|
||||
"\u001b[1mConcurrent executed in 1.39 seconds.\u001b[0m\n",
|
||||
"\u001B[1mConcurrent executed in 1.39 seconds.\u001B[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"I'm doing well, thank you. How about you?\n",
|
||||
@ -86,7 +86,7 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"I'm doing well, thanks for asking. How about you?\n",
|
||||
"\u001b[1mSerial executed in 5.77 seconds.\u001b[0m\n"
|
||||
"\u001B[1mSerial executed in 5.77 seconds.\u001B[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -24,6 +24,8 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
client: Any #: :meta private:
|
||||
"""Cohere client."""
|
||||
async_client: Any #: :meta private:
|
||||
"""Cohere async client."""
|
||||
model: str = "embed-english-v2.0"
|
||||
"""Model name to use."""
|
||||
|
||||
@ -47,6 +49,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
import cohere
|
||||
|
||||
values["client"] = cohere.Client(cohere_api_key)
|
||||
values["async_client"] = cohere.AsyncClient(cohere_api_key)
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import cohere python package. "
|
||||
@ -68,6 +71,20 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
).embeddings
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed a batch of documents using Cohere's async client.

    Args:
        texts: The list of texts to embed.

    Returns:
        List of embeddings, one for each input text.
    """
    # NOTE(review): assumes `async_client` is a cohere.AsyncClient whose
    # `.embed(...)` response exposes an `.embeddings` attribute — confirm.
    response = await self.async_client.embed(
        model=self.model, texts=texts, truncate=self.truncate
    )
    # Coerce every component to float so callers get plain Python lists.
    return [[float(component) for component in row] for row in response.embeddings]
|
||||
def embed_query(self, text: str) -> List[float]:
    """Embed a single query text via Cohere's embedding endpoint.

    Args:
        text: The text to embed.

    Returns:
        Embeddings for the text.
    """
    # Delegate to the batch path with a one-element list and unwrap.
    results = self.embed_documents([text])
    return results[0]
|
||||
async def aembed_query(self, text: str) -> List[float]:
    """Async-embed a single query text via Cohere's embedding endpoint.

    Args:
        text: The text to embed.

    Returns:
        Embeddings for the text.
    """
    # Reuse the async batch path for a single-element request.
    return (await self.aembed_documents([text]))[0]
|
@ -12,7 +12,10 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@ -47,6 +50,17 @@ def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    async def _generate_once(**call_kwargs: Any) -> Any:
        # tenacity re-invokes this coroutine on retryable errors.
        return await llm.async_client.generate(**call_kwargs)

    # Returns an awaitable; callers `await acompletion_with_retry(...)`.
    return _generate_once(**kwargs)
|
||||
|
||||
class Cohere(LLM):
|
||||
"""Cohere large language models.
|
||||
|
||||
@ -62,6 +76,7 @@ class Cohere(LLM):
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
async_client: Any #: :meta private:
|
||||
model: Optional[str] = None
|
||||
"""Model name to use."""
|
||||
|
||||
@ -109,6 +124,7 @@ class Cohere(LLM):
|
||||
import cohere
|
||||
|
||||
values["client"] = cohere.Client(cohere_api_key)
|
||||
values["async_client"] = cohere.AsyncClient(cohere_api_key)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import cohere python package. "
|
||||
@ -139,6 +155,24 @@ class Cohere(LLM):
|
||||
"""Return type of llm."""
|
||||
return "cohere"
|
||||
|
||||
def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
    """Build the parameter dict for a Cohere generate call.

    Args:
        stop: Optional stop sequences supplied by the caller.
        **kwargs: Extra parameters that override the defaults.

    Returns:
        Merged parameters including resolved ``stop_sequences``.

    Raises:
        ValueError: If stop sequences are supplied both on the instance
            and in the call.
    """
    # Copy defensively: the original assigned into the dict returned by
    # `_default_params`, which would leak `stop_sequences` across calls
    # if that dict is ever shared/cached.
    params = dict(self._default_params)
    if self.stop is not None and stop is not None:
        raise ValueError("`stop` found in both the input and default params.")
    elif self.stop is not None:
        params["stop_sequences"] = self.stop
    else:
        params["stop_sequences"] = stop
    return {**params, **kwargs}
|
||||
def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
    """Pull the generated text out of a Cohere response.

    Cohere's endpoint echoes any stop tokens back in the generation; they
    are stripped here so the result is consistent with other endpoints.
    """
    text = response.generations[0].text
    if not stop:
        return text
    return enforce_stop_tokens(text, stop)
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -160,20 +194,37 @@ class Cohere(LLM):
|
||||
|
||||
response = cohere("Tell me a joke.")
|
||||
"""
|
||||
params = self._default_params
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
params["stop_sequences"] = self.stop
|
||||
else:
|
||||
params["stop_sequences"] = stop
|
||||
params = {**params, **kwargs}
|
||||
params = self._invocation_params(stop, **kwargs)
|
||||
response = completion_with_retry(
|
||||
self, model=self.model, prompt=prompt, **params
|
||||
)
|
||||
text = response.generations[0].text
|
||||
# If stop tokens are provided, Cohere's endpoint returns them.
|
||||
# In order to make this consistent with other endpoints, we strip them.
|
||||
if stop is not None or self.stop is not None:
|
||||
text = enforce_stop_tokens(text, params["stop_sequences"])
|
||||
return text
|
||||
_stop = params.get("stop_sequences")
|
||||
return self._process_response(response, _stop)
|
||||
|
||||
async def _acall(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> str:
    """Async call out to Cohere's generate endpoint.

    Args:
        prompt: The prompt to pass into the model.
        stop: Optional list of stop words to use when generating.

    Returns:
        The string generated by the model.

    Example:
        .. code-block:: python

            response = await cohere("Tell me a joke.")
    """
    generate_params = self._invocation_params(stop, **kwargs)
    raw_response = await acompletion_with_retry(
        self, model=self.model, prompt=prompt, **generate_params
    )
    # Strip whichever stop sequences were actually resolved into params.
    return self._process_response(
        raw_response, generate_params.get("stop_sequences")
    )
|
66
libs/langchain/poetry.lock
generated
66
libs/langchain/poetry.lock
generated
@ -1863,18 +1863,23 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency
|
||||
|
||||
[[package]]
|
||||
name = "cohere"
|
||||
version = "3.10.0"
|
||||
description = "A Python library for the Cohere API"
|
||||
version = "4.18.0"
|
||||
description = ""
|
||||
category = "main"
|
||||
optional = true
|
||||
python-versions = ">=3.6"
|
||||
python-versions = ">=3.7,<4.0"
|
||||
files = [
|
||||
{file = "cohere-3.10.0.tar.gz", hash = "sha256:8c06a87a47aa9521051eeba130ce391d84ab578148c4ea5b62f6dcc41bd3a274"},
|
||||
{file = "cohere-4.18.0-py3-none-any.whl", hash = "sha256:26b5be3f93c0046be7fd89b2e724190e10f9fceac8bcf8f22581368a1f3af2e4"},
|
||||
{file = "cohere-4.18.0.tar.gz", hash = "sha256:ed3d5703384412312fd827e669364b2f0eb3678a1206987cb3e1d98b88409c31"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
requests = "*"
|
||||
urllib3 = ">=1.26,<2.0"
|
||||
aiohttp = ">=3.0,<4.0"
|
||||
backoff = ">=2.0,<3.0"
|
||||
fastavro = "1.7.4"
|
||||
importlib_metadata = ">=6.0,<7.0"
|
||||
requests = ">=2.25.0,<3.0.0"
|
||||
urllib3 = ">=1.26,<3"
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
@ -2689,6 +2694,53 @@ dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.138)", "uvicorn[standard] (>
|
||||
doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer-cli (>=0.0.13,<0.0.14)", "typer[all] (>=0.6.1,<0.8.0)"]
|
||||
test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6.5.0,<8.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.7)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.138)", "sqlalchemy (>=1.3.18,<1.4.43)", "types-orjson (==3.6.2)", "types-ujson (==5.7.0.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "fastavro"
|
||||
version = "1.7.4"
|
||||
description = "Fast read/write of AVRO files"
|
||||
category = "main"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "fastavro-1.7.4-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7568e621b94e061974b2a96d70670d09910e0a71482dd8610b153c07bd768497"},
|
||||
{file = "fastavro-1.7.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4ec994faf64b743647f0027fcc56b01dc15d46c0e48fa15828277cb02dbdcd6"},
|
||||
{file = "fastavro-1.7.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:727fdc1ddd12fcc6addab0b6df12ef999a6babe4b753db891f78aa2ee33edc77"},
|
||||
{file = "fastavro-1.7.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b2f0cb3f7795fcb0042e0bbbe51204c28338a455986d68409b26dcbde64dd69a"},
|
||||
{file = "fastavro-1.7.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bb0a8b5016a99be4b8ce3550889a1bd968c0fb3f521bcfbae24210c6342aee0c"},
|
||||
{file = "fastavro-1.7.4-cp310-cp310-win_amd64.whl", hash = "sha256:1d2040b2bf3dc1a75170ea44d1e7e09f84fb77f40ef2e6c6b9f2eaf710557083"},
|
||||
{file = "fastavro-1.7.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5542423f46bb7fc9699c467cbf151c2713aa6976ef14f4f5ec3532d80d0bb616"},
|
||||
{file = "fastavro-1.7.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec396e6ab6b272708c8b9a0142df01fff4c7a1f168050f292ab92fdaee0b0257"},
|
||||
{file = "fastavro-1.7.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b10d68c03371b79f461feca1c6c7e9d3f6aea2e9c7472b25cd749c57562aa1"},
|
||||
{file = "fastavro-1.7.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f94d5168ec72f3cfcf2181df1c46ad240dc1fcf361717447d2c5237121b9df55"},
|
||||
{file = "fastavro-1.7.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bad3dc279ed4ce747989259035cb3607f189ef7aff40339202f9321ca7f83d0b"},
|
||||
{file = "fastavro-1.7.4-cp311-cp311-win_amd64.whl", hash = "sha256:8480ff444d9c7abd0bf121dd68656bd2115caca8ed28e71936eff348fde706e0"},
|
||||
{file = "fastavro-1.7.4-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:bd3d669f4ec6915c88bb80b7c14e01d2c3ceb93a61de5dcf33ff13972bba505e"},
|
||||
{file = "fastavro-1.7.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a312b128536b81bdb79f27076f513b998abe7d13ee6fe52e99bc01f7ad9b06a"},
|
||||
{file = "fastavro-1.7.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:487054d1419f1bfa41e7f19c718cbdbbb254319d3fd5b9ac411054d6432b9d40"},
|
||||
{file = "fastavro-1.7.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d2897fe7d1d5b27dcd33c43d68480de36e55a0e651d7731004a36162cd3eed9e"},
|
||||
{file = "fastavro-1.7.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6d318b49fd648a1fd93394411fe23761b486ac65dadea7c52dbeb0d0bef30221"},
|
||||
{file = "fastavro-1.7.4-cp37-cp37m-win_amd64.whl", hash = "sha256:a117c3b122a8110c6ab99b3e66736790b4be19ceefb1edf0e732c33b3dc411c8"},
|
||||
{file = "fastavro-1.7.4-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:0cca15e1a1f829e40524004342e425acfb594cefbd3388b0a5d13542750623ac"},
|
||||
{file = "fastavro-1.7.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9211ec7a18a46a2aee01a2a979fd79f05f36b11fdb1bc469c9d9fd8cec32579"},
|
||||
{file = "fastavro-1.7.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f16bde6b5fb51e15233bfcee0378f48d4221201ba45e497a8063f6d216b7aad7"},
|
||||
{file = "fastavro-1.7.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aeca55c905ff4c667f2158564654a778918988811ae3eb28592767edcf5f5c4a"},
|
||||
{file = "fastavro-1.7.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b244f3abc024fc043d6637284ba2ffee5a1291c08a0f361ea1af4d829f66f303"},
|
||||
{file = "fastavro-1.7.4-cp38-cp38-win_amd64.whl", hash = "sha256:b64e394c87cb99d0681727e1ae5d3633906a72abeab5ea0c692394aeb5a56607"},
|
||||
{file = "fastavro-1.7.4-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:8c8115bdb1c862354d9abd0ea23eab85793bbff139087f2607bd4b83e8ae07ab"},
|
||||
{file = "fastavro-1.7.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b27dd08f2338a478185c6ba23308002f334642ce83a6aeaf8308271efef88062"},
|
||||
{file = "fastavro-1.7.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f087c246afab8bac08d86ef21be87cbf4f3779348fb960c081863fc3d570412c"},
|
||||
{file = "fastavro-1.7.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b4077e17a2bab37af96e5ca52e61b6f2b85e4577e7a2903f6814642eb6a834f7"},
|
||||
{file = "fastavro-1.7.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:776511cecf2ea9da4edd0de5015c1562cd9063683cf94f79bc9e20bab8f06923"},
|
||||
{file = "fastavro-1.7.4-cp39-cp39-win_amd64.whl", hash = "sha256:a7ea5565fe2c145e074ce9ba75fafd5479a86b34a8dbd00dd1835cf192290e14"},
|
||||
{file = "fastavro-1.7.4.tar.gz", hash = "sha256:6450f47ac4db95ec3a9e6434fec1f8a3c4c8c941de16205832ca8c67dd23d0d2"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
codecs = ["lz4", "python-snappy", "zstandard"]
|
||||
lz4 = ["lz4"]
|
||||
snappy = ["python-snappy"]
|
||||
zstandard = ["zstandard"]
|
||||
|
||||
[[package]]
|
||||
name = "fastjsonschema"
|
||||
version = "2.17.1"
|
||||
@ -12500,4 +12552,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "7a8847de4dd88e71b423ff148823523220a5649340178e8ab1f7bafb03a290d2"
|
||||
content-hash = "4f5d91f450555bb3a039c3aef4a7996d1322f25608ec17a7b0c1ad92813d6a63"
|
||||
|
@ -47,7 +47,7 @@ qdrant-client = {version = "^1.3.1", optional = true, python = ">=3.8.1,<3.12"}
|
||||
dataclasses-json = "^0.5.7"
|
||||
tensorflow-text = {version = "^2.11.0", optional = true, python = "^3.10, <3.12"}
|
||||
tenacity = "^8.1.0"
|
||||
cohere = {version = "^3", optional = true}
|
||||
cohere = {version = "^4", optional = true}
|
||||
openai = {version = "^0", optional = true}
|
||||
nlpcloud = {version = "^1", optional = true}
|
||||
nomic = {version = "^1.0.43", optional = true}
|
||||
|
Loading…
Reference in New Issue
Block a user