Add cohere /chat integration (#11389)

Add cohere /chat integration and an IPython notebook to demonstrate the addition.

Co-authored-by: Bagatur <baskaryan@gmail.com>

commit 2ff91a46c0 (parent ca346011b7)
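For orientation, a minimal usage sketch of the new integration (it mirrors the notebook below and assumes the cohere package is installed and COHERE_API_KEY is set):

    from langchain.chat_models import ChatCohere
    from langchain.schema import HumanMessage

    chat = ChatCohere()
    result = chat([HumanMessage(content="knock knock")])
    print(result)  # -> AIMessage(content="Who's there?")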
docs/extras/integrations/chat/cohere.ipynb | 174 lines, new file
@@ -0,0 +1,174 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bf733a38-db84-4363-89e2-de6735c37230",
   "metadata": {},
   "source": [
    "# Cohere\n",
    "\n",
    "This notebook covers how to get started with Cohere chat models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.chat_models import ChatCohere\n",
    "from langchain.schema import AIMessage, HumanMessage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chat = ChatCohere()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"Who's there?\")"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages = [\n",
    "    HumanMessage(\n",
    "        content=\"knock knock\"\n",
    "    )\n",
    "]\n",
    "chat(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c",
   "metadata": {},
   "source": [
    "## `ChatCohere` also supports async and streaming functionality:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "93a21c5c-6ef9-4688-be60-b2e1f94842fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.callbacks.manager import CallbackManager\n",
    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Who's there?"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LLMResult(generations=[[ChatGenerationChunk(text=\"Who's there?\", message=AIMessageChunk(content=\"Who's there?\"))]], llm_output={}, run=[RunInfo(run_id=UUID('1e9eaefc-9c99-4fa9-8297-ef9975d4751e'))])"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await chat.agenerate([messages])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Who's there?"
     ]
    },
    {
     "data": {
      "text/plain": [
       "AIMessageChunk(content=\"Who's there?\")"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat = ChatCohere(\n",
    "    streaming=True,\n",
    "    verbose=True,\n",
    "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
    ")\n",
    "chat(messages)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
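Outside a notebook, where top-level await is not available, the async call shown above can be driven with asyncio. A minimal sketch, assuming the same environment as the notebook:

    import asyncio

    from langchain.chat_models import ChatCohere
    from langchain.schema import HumanMessage

    async def main() -> None:
        chat = ChatCohere()
        messages = [HumanMessage(content="knock knock")]
        result = await chat.agenerate([messages])  # returns an LLMResult
        print(result.generations[0][0].text)

    asyncio.run(main())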
libs/langchain/langchain/chat_models/__init__.py

@@ -22,6 +22,7 @@ from langchain.chat_models.anyscale import ChatAnyscale
 from langchain.chat_models.azure_openai import AzureChatOpenAI
 from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
 from langchain.chat_models.bedrock import BedrockChat
+from langchain.chat_models.cohere import ChatCohere
 from langchain.chat_models.ernie import ErnieBotChat
 from langchain.chat_models.fake import FakeListChatModel
 from langchain.chat_models.fireworks import ChatFireworks
@@ -45,6 +46,7 @@ __all__ = [
     "FakeListChatModel",
     "PromptLayerChatOpenAI",
     "ChatAnthropic",
+    "ChatCohere",
     "ChatGooglePalm",
     "ChatMLflowAIGateway",
     "ChatOllama",
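With both the import and the __all__ entry in place, the new chat model resolves from the package namespace; a quick sanity check:

    from langchain.chat_models import ChatCohere  # resolves via the re-export above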
libs/langchain/langchain/chat_models/cohere.py | 162 lines, new file

@@ -0,0 +1,162 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import (
    BaseChatModel,
    _agenerate_from_stream,
    _generate_from_stream,
)
from langchain.llms.cohere import BaseCohere
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult


def get_role(message: BaseMessage) -> str:
    if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
        return "User"
    elif isinstance(message, AIMessage):
        return "Chatbot"
    elif isinstance(message, SystemMessage):
        return "System"
    else:
        raise ValueError(f"Got unknown type {message}")


class ChatCohere(BaseChatModel, BaseCohere):
    """`Cohere` chat large language models.

    To use, you should have the ``cohere`` python package installed, and the
    environment variable ``COHERE_API_KEY`` set with your API key, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatCohere
            from langchain.schema import HumanMessage

            chat = ChatCohere(model="foo")
            result = chat([HumanMessage(content="Hello")])
            print(result.content)
    """

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "cohere-chat"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Cohere API."""
        return {
            "temperature": self.temperature,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    def get_cohere_chat_request(
        self, messages: List[BaseMessage], **kwargs: Any
    ) -> Dict[str, Any]:
        return {
            "message": messages[0].content,
            "chat_history": [
                {"role": get_role(x), "message": x.content} for x in messages[1:]
            ],
            **self._default_params,
            **kwargs,
        }

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        request = self.get_cohere_chat_request(messages, **kwargs)
        stream = self.client.chat(**request, stream=True)

        for data in stream:
            if data.event_type == "text-generation":
                delta = data.text
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    run_manager.on_llm_new_token(delta)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        request = self.get_cohere_chat_request(messages, **kwargs)
        stream = await self.async_client.chat(**request, stream=True)

        async for data in stream:
            if data.event_type == "text-generation":
                delta = data.text
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    await run_manager.on_llm_new_token(delta)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return _generate_from_stream(stream_iter)

        request = self.get_cohere_chat_request(messages, **kwargs)
        response = self.client.chat(**request)

        message = AIMessage(content=response.text)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await _agenerate_from_stream(stream_iter)

        request = self.get_cohere_chat_request(messages, **kwargs)
        response = self.client.chat(**request, stream=False)

        message = AIMessage(content=response.text)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def get_num_tokens(self, text: str) -> int:
        """Calculate number of tokens."""
        return len(self.client.tokenize(text).tokens)
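To make the request mapping concrete, here is an illustrative sketch of what get_cohere_chat_request produces for a short conversation (assumes COHERE_API_KEY is set so the validator can build the clients; the temperature value is just an example):

    from langchain.chat_models import ChatCohere
    from langchain.schema.messages import AIMessage, HumanMessage

    chat = ChatCohere(temperature=0.3)
    messages = [
        HumanMessage(content="knock knock"),
        AIMessage(content="Who's there?"),
        HumanMessage(content="Tank"),
    ]
    request = chat.get_cohere_chat_request(messages)
    # Per the method above: messages[0] becomes the top-level "message", and the
    # remaining turns become "chat_history" with get_role() supplying Cohere role names:
    # {"message": "knock knock",
    #  "chat_history": [{"role": "Chatbot", "message": "Who's there?"},
    #                   {"role": "User", "message": "Tank"}],
    #  "temperature": 0.3}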
libs/langchain/langchain/llms/cohere.py

@@ -17,7 +17,8 @@ from langchain.callbacks.manager import (
 )
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.load.serializable import Serializable
+from langchain.pydantic_v1 import Extra, Field, root_validator
 from langchain.utils import get_from_dict_or_env

 logger = logging.getLogger(__name__)
@@ -61,7 +62,42 @@ def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
     return _completion_with_retry(**kwargs)


-class Cohere(LLM):
+class BaseCohere(Serializable):
+    client: Any  #: :meta private:
+    async_client: Any  #: :meta private:
+    model: Optional[str] = Field(default=None)
+    """Model name to use."""
+
+    temperature: float = 0.75
+    """A non-negative float that tunes the degree of randomness in generation."""
+
+    cohere_api_key: Optional[str] = None
+
+    stop: Optional[List[str]] = None
+
+    streaming: bool = Field(default=False)
+    """Whether to stream the results."""
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError(
+                "Could not import cohere python package. "
+                "Please install it with `pip install cohere`."
+            )
+        else:
+            cohere_api_key = get_from_dict_or_env(
+                values, "cohere_api_key", "COHERE_API_KEY"
+            )
+            values["client"] = cohere.Client(cohere_api_key)
+            values["async_client"] = cohere.AsyncClient(cohere_api_key)
+        return values
+
+
+class Cohere(LLM, BaseCohere):
     """Cohere large language models.

     To use, you should have the ``cohere`` python package installed, and the
@@ -72,20 +108,13 @@ class Cohere(LLM):
     .. code-block:: python

         from langchain.llms import Cohere

         cohere = Cohere(model="gptd-instruct-tft", cohere_api_key="my-api-key")
     """

-    client: Any  #: :meta private:
-    async_client: Any  #: :meta private:
-    model: Optional[str] = None
-    """Model name to use."""
-
     max_tokens: int = 256
     """Denotes the number of tokens to predict per generation."""

-    temperature: float = 0.75
-    """A non-negative float that tunes the degree of randomness in generation."""
-
     k: int = 0
     """Number of most likely tokens to consider at each step."""
@@ -105,33 +134,11 @@ class Cohere(LLM):
     max_retries: int = 10
     """Maximum number of retries to make when generating."""

-    cohere_api_key: Optional[str] = None
-
-    stop: Optional[List[str]] = None
-
     class Config:
         """Configuration for this pydantic object."""

         extra = Extra.forbid

-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key and python package exists in environment."""
-        cohere_api_key = get_from_dict_or_env(
-            values, "cohere_api_key", "COHERE_API_KEY"
-        )
-        try:
-            import cohere
-
-            values["client"] = cohere.Client(cohere_api_key)
-            values["async_client"] = cohere.AsyncClient(cohere_api_key)
-        except ImportError:
-            raise ImportError(
-                "Could not import cohere python package. "
-                "Please install it with `pip install cohere`."
-            )
-        return values
-
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Cohere API."""
@@ -145,6 +152,10 @@ class Cohere(LLM):
             "truncate": self.truncate,
         }

+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"cohere_api_key": "COHERE_API_KEY"}
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
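The upshot of this refactor: Cohere and ChatCohere now share field definitions and client construction through BaseCohere, so both wrappers get client and async_client populated by the same validator. A sketch, assuming the cohere package is installed and COHERE_API_KEY is exported ("command" is an assumed model name, not taken from this diff):

    from langchain.chat_models import ChatCohere
    from langchain.llms import Cohere

    llm = Cohere(model="command")        # completion-style wrapper
    chat = ChatCohere(temperature=0.75)  # chat wrapper; same client setup
    # Both instances had BaseCohere.validate_environment run on construction,
    # which built cohere.Client and cohere.AsyncClient from COHERE_API_KEY.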