mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-04 20:28:10 +00:00)
Add cohere /chat integration (#11389)
Add cohere /chat integration and an IPython notebook to demonstrate the addition.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent ca346011b7
commit 2ff91a46c0
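
For orientation before the diff: a minimal sketch of the usage pattern this commit enables, mirroring the notebook added below (it assumes the `cohere` package is installed and `COHERE_API_KEY` is exported):

    from langchain.chat_models import ChatCohere
    from langchain.schema import HumanMessage

    # Build the chat model; the API key is read from COHERE_API_KEY.
    chat = ChatCohere()

    # Send a single human message and print the model's reply.
    result = chat([HumanMessage(content="knock knock")])
    print(result.content)  # e.g. "Who's there?"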
docs/extras/integrations/chat/cohere.ipynb (new file, +174 lines)
@@ -0,0 +1,174 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bf733a38-db84-4363-89e2-de6735c37230",
   "metadata": {},
   "source": [
    "# Cohere\n",
    "\n",
    "This notebook covers how to get started with Cohere chat models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.chat_models import ChatCohere\n",
    "from langchain.schema import AIMessage, HumanMessage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chat = ChatCohere()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"Who's there?\")"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages = [\n",
    "    HumanMessage(\n",
    "        content=\"knock knock\"\n",
    "    )\n",
    "]\n",
    "chat(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c",
   "metadata": {},
   "source": [
    "## `ChatCohere` also supports async and streaming functionality:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "93a21c5c-6ef9-4688-be60-b2e1f94842fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.callbacks.manager import CallbackManager\n",
    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Who's there?"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LLMResult(generations=[[ChatGenerationChunk(text=\"Who's there?\", message=AIMessageChunk(content=\"Who's there?\"))]], llm_output={}, run=[RunInfo(run_id=UUID('1e9eaefc-9c99-4fa9-8297-ef9975d4751e'))])"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await chat.agenerate([messages])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Who's there?"
     ]
    },
    {
     "data": {
      "text/plain": [
       "AIMessageChunk(content=\"Who's there?\")"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat = ChatCohere(\n",
    "    streaming=True,\n",
    "    verbose=True,\n",
    "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
    ")\n",
    "chat(messages)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
libs/langchain/langchain/chat_models/__init__.py (modified)
@@ -22,6 +22,7 @@ from langchain.chat_models.anyscale import ChatAnyscale
 from langchain.chat_models.azure_openai import AzureChatOpenAI
 from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
 from langchain.chat_models.bedrock import BedrockChat
+from langchain.chat_models.cohere import ChatCohere
 from langchain.chat_models.ernie import ErnieBotChat
 from langchain.chat_models.fake import FakeListChatModel
 from langchain.chat_models.fireworks import ChatFireworks
@@ -45,6 +46,7 @@ __all__ = [
     "FakeListChatModel",
     "PromptLayerChatOpenAI",
     "ChatAnthropic",
+    "ChatCohere",
     "ChatGooglePalm",
     "ChatMLflowAIGateway",
     "ChatOllama",
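
Since this hunk only re-exports the class at the package level, both import paths resolve to the same object; a quick sanity-check sketch:

    from langchain.chat_models import ChatCohere as FromPackage
    from langchain.chat_models.cohere import ChatCohere as FromModule

    # The package-level name is just a re-export of the module-level class.
    assert FromPackage is FromModule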
libs/langchain/langchain/chat_models/cohere.py (new file, +162 lines)
@@ -0,0 +1,162 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import (
    BaseChatModel,
    _agenerate_from_stream,
    _generate_from_stream,
)
from langchain.llms.cohere import BaseCohere
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult


def get_role(message: BaseMessage) -> str:
    if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
        return "User"
    elif isinstance(message, AIMessage):
        return "Chatbot"
    elif isinstance(message, SystemMessage):
        return "System"
    else:
        raise ValueError(f"Got unknown type {message}")


class ChatCohere(BaseChatModel, BaseCohere):
    """`Cohere` chat large language models.

    To use, you should have the ``cohere`` python package installed, and the
    environment variable ``COHERE_API_KEY`` set with your API key, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatCohere
            from langchain.schema import HumanMessage

            chat = ChatCohere(model="foo")
            result = chat([HumanMessage(content="Hello")])
            print(result.content)
    """

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "cohere-chat"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Cohere API."""
        return {
            "temperature": self.temperature,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    def get_cohere_chat_request(
        self, messages: List[BaseMessage], **kwargs: Any
    ) -> Dict[str, Any]:
        return {
            "message": messages[0].content,
            "chat_history": [
                {"role": get_role(x), "message": x.content} for x in messages[1:]
            ],
            **self._default_params,
            **kwargs,
        }

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        request = self.get_cohere_chat_request(messages, **kwargs)
        stream = self.client.chat(**request, stream=True)

        for data in stream:
            if data.event_type == "text-generation":
                delta = data.text
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    run_manager.on_llm_new_token(delta)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        request = self.get_cohere_chat_request(messages, **kwargs)
        stream = await self.async_client.chat(**request, stream=True)

        async for data in stream:
            if data.event_type == "text-generation":
                delta = data.text
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    await run_manager.on_llm_new_token(delta)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return _generate_from_stream(stream_iter)

        request = self.get_cohere_chat_request(messages, **kwargs)
        response = self.client.chat(**request)

        message = AIMessage(content=response.text)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await _agenerate_from_stream(stream_iter)

        request = self.get_cohere_chat_request(messages, **kwargs)
        response = self.client.chat(**request, stream=False)

        message = AIMessage(content=response.text)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def get_num_tokens(self, text: str) -> int:
        """Calculate number of tokens."""
        return len(self.client.tokenize(text).tokens)
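
To make the request construction above concrete: `get_cohere_chat_request` sends the first message's content as Cohere's `message` field and maps the remaining messages into `chat_history` via `get_role`. A small offline sketch of that mapping (no API call, message classes only):

    from langchain.chat_models.cohere import get_role
    from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage

    # Roles as the Cohere chat endpoint expects them:
    print(get_role(SystemMessage(content="Be terse.")))   # System
    print(get_role(HumanMessage(content="knock knock")))  # User
    print(get_role(AIMessage(content="Who's there?")))    # Chatbot

    # So for messages [m0, m1, m2], the request body is roughly:
    # {"message": m0.content,
    #  "chat_history": [{"role": get_role(m1), "message": m1.content},
    #                   {"role": get_role(m2), "message": m2.content}],
    #  "temperature": 0.75}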
libs/langchain/langchain/llms/cohere.py (modified)
@@ -17,7 +17,8 @@ from langchain.callbacks.manager import (
 )
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.load.serializable import Serializable
+from langchain.pydantic_v1 import Extra, Field, root_validator
 from langchain.utils import get_from_dict_or_env

 logger = logging.getLogger(__name__)
@@ -61,7 +62,42 @@ def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
     return _completion_with_retry(**kwargs)


-class Cohere(LLM):
+class BaseCohere(Serializable):
+    client: Any  #: :meta private:
+    async_client: Any  #: :meta private:
+    model: Optional[str] = Field(default=None)
+    """Model name to use."""
+
+    temperature: float = 0.75
+    """A non-negative float that tunes the degree of randomness in generation."""
+
+    cohere_api_key: Optional[str] = None
+
+    stop: Optional[List[str]] = None
+
+    streaming: bool = Field(default=False)
+    """Whether to stream the results."""
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError(
+                "Could not import cohere python package. "
+                "Please install it with `pip install cohere`."
+            )
+        else:
+            cohere_api_key = get_from_dict_or_env(
+                values, "cohere_api_key", "COHERE_API_KEY"
+            )
+            values["client"] = cohere.Client(cohere_api_key)
+            values["async_client"] = cohere.AsyncClient(cohere_api_key)
+        return values
+
+
+class Cohere(LLM, BaseCohere):
     """Cohere large language models.

     To use, you should have the ``cohere`` python package installed, and the
@@ -72,20 +108,13 @@ class Cohere(LLM):
         .. code-block:: python

             from langchain.llms import Cohere

             cohere = Cohere(model="gptd-instruct-tft", cohere_api_key="my-api-key")
     """

-    client: Any  #: :meta private:
-    async_client: Any  #: :meta private:
-    model: Optional[str] = None
-    """Model name to use."""
-
     max_tokens: int = 256
     """Denotes the number of tokens to predict per generation."""

-    temperature: float = 0.75
-    """A non-negative float that tunes the degree of randomness in generation."""
-
     k: int = 0
     """Number of most likely tokens to consider at each step."""

@@ -105,33 +134,11 @@ class Cohere(LLM):
     max_retries: int = 10
     """Maximum number of retries to make when generating."""

-    cohere_api_key: Optional[str] = None
-
-    stop: Optional[List[str]] = None
-
     class Config:
         """Configuration for this pydantic object."""

         extra = Extra.forbid

-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key and python package exists in environment."""
-        cohere_api_key = get_from_dict_or_env(
-            values, "cohere_api_key", "COHERE_API_KEY"
-        )
-        try:
-            import cohere
-
-            values["client"] = cohere.Client(cohere_api_key)
-            values["async_client"] = cohere.AsyncClient(cohere_api_key)
-        except ImportError:
-            raise ImportError(
-                "Could not import cohere python package. "
-                "Please install it with `pip install cohere`."
-            )
-        return values
-
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Cohere API."""
@@ -145,6 +152,10 @@ class Cohere(LLM):
             "truncate": self.truncate,
         }

+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"cohere_api_key": "COHERE_API_KEY"}
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
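
The net effect of this refactor is that the completion-style `Cohere` LLM and the new `ChatCohere` share one base: `BaseCohere` now owns the common fields (`model`, `temperature`, `cohere_api_key`, `stop`, `streaming`) and builds both the sync and async clients in its root validator. A hedged sketch (assumes `cohere` is installed and `COHERE_API_KEY` is set):

    from langchain.chat_models import ChatCohere
    from langchain.llms import Cohere
    from langchain.llms.cohere import BaseCohere

    # Both model classes inherit fields and client setup from BaseCohere.
    assert issubclass(Cohere, BaseCohere)
    assert issubclass(ChatCohere, BaseCohere)

    llm = Cohere()       # validate_environment builds client and async_client
    chat = ChatCohere()  # same validation path; temperature defaults to 0.75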