From 392f1b32188d40e45adabb85a1641780eedd006b Mon Sep 17 00:00:00 2001
From: Mike Lambert
Date: Fri, 14 Apr 2023 18:09:07 -0400
Subject: [PATCH] Add Anthropic ChatModel to langchain (#2293)

* Adds an Anthropic ChatModel
* Factors out common code in our LLMModel and ChatModel
* Supports streaming LLM tokens to the callbacks on a delta basis (until a
  future V2 API does that for us)
* Some fixes
---
 langchain/chat_models/anthropic.py            | 145 +++++++++++++++
 langchain/llms/anthropic.py                   | 172 +++++++++---------
 .../chat_models/test_anthropic.py             |  81 +++++++++
 .../integration_tests/llms/test_anthropic.py  |   2 -
 4 files changed, 316 insertions(+), 84 deletions(-)
 create mode 100644 langchain/chat_models/anthropic.py
 create mode 100644 tests/integration_tests/chat_models/test_anthropic.py

diff --git a/langchain/chat_models/anthropic.py b/langchain/chat_models/anthropic.py
new file mode 100644
index 00000000000..b63fbf052d9
--- /dev/null
+++ b/langchain/chat_models/anthropic.py
@@ -0,0 +1,145 @@
+from typing import List, Optional
+
+from pydantic import Extra
+
+from langchain.chat_models.base import BaseChatModel
+from langchain.llms.anthropic import _AnthropicCommon
+from langchain.schema import (
+    AIMessage,
+    BaseMessage,
+    ChatGeneration,
+    ChatMessage,
+    ChatResult,
+    HumanMessage,
+    SystemMessage,
+)
+
+
+class ChatAnthropic(BaseChatModel, _AnthropicCommon):
+    r"""Wrapper around Anthropic's chat model.
+
+    To use, you should have the ``anthropic`` python package installed, and the
+    environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
+    it as a named parameter to the constructor.
+
+    Example:
+        .. code-block:: python
+            from langchain.chat_models import ChatAnthropic
+            from langchain.schema import HumanMessage
+            model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
+
+            # Messages are automatically wrapped with HUMAN_PROMPT and
+            # AI_PROMPT before being sent to the model:
+            response = model([HumanMessage(content="What are the biggest risks facing humanity?")])
+    """
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return "anthropic-chat"
+
+    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
+        if isinstance(message, ChatMessage):
+            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+        elif isinstance(message, HumanMessage):
+            message_text = f"{self.HUMAN_PROMPT} {message.content}"
+        elif isinstance(message, AIMessage):
+            message_text = f"{self.AI_PROMPT} {message.content}"
+        elif isinstance(message, SystemMessage):
+            message_text = f"{self.HUMAN_PROMPT} {message.content}"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_text
+
+    def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
+        """Format a list of messages into a single string with necessary newlines.
+
+        Args:
+            messages (List[BaseMessage]): List of BaseMessage to combine.
+
+        Returns:
+            str: Combined string with necessary newlines.
+        """
+        return "".join(
+            self._convert_one_message_to_text(message) for message in messages
+        )
+
+    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
+        """Format a list of messages into a full prompt for the Anthropic model.
+
+        Args:
+            messages (List[BaseMessage]): List of BaseMessage to combine.
+
+        Returns:
+            str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
+        """
+        if not self.AI_PROMPT:
+            raise NameError("Please ensure the anthropic package is loaded")
+
+        if not isinstance(messages[-1], AIMessage):
+            messages = messages + [AIMessage(content="")]
+        text = self._convert_messages_to_text(messages)
+        return (
+            text.rstrip()
+        )  # trim off the trailing space left by the empty "Assistant: " turn
+
+    def _generate(
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+    ) -> ChatResult:
+        prompt = self._convert_messages_to_prompt(messages)
+        params = {"prompt": prompt, **self._default_params}
+        if stop:
+            params["stop_sequences"] = stop
+
+        if self.streaming:
+            completion = ""
+            stream_resp = self.client.completion_stream(**params)
+            for data in stream_resp:
+                delta = data["completion"][len(completion) :]
+                completion = data["completion"]
+                self.callback_manager.on_llm_new_token(
+                    delta,
+                    verbose=self.verbose,
+                )
+        else:
+            response = self.client.completion(**params)
+            completion = response["completion"]
+        message = AIMessage(content=completion)
+        return ChatResult(generations=[ChatGeneration(message=message)])
+
+    async def _agenerate(
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+    ) -> ChatResult:
+        prompt = self._convert_messages_to_prompt(messages)
+        params = {"prompt": prompt, **self._default_params}
+        if stop:
+            params["stop_sequences"] = stop
+
+        if self.streaming:
+            completion = ""
+            stream_resp = await self.client.acompletion_stream(**params)
+            async for data in stream_resp:
+                delta = data["completion"][len(completion) :]
+                completion = data["completion"]
+                if self.callback_manager.is_async:
+                    await self.callback_manager.on_llm_new_token(
+                        delta,
+                        verbose=self.verbose,
+                    )
+                else:
+                    self.callback_manager.on_llm_new_token(
+                        delta,
+                        verbose=self.verbose,
+                    )
+        else:
+            response = await self.client.acompletion(**params)
+            completion = response["completion"]
+        message = AIMessage(content=completion)
+        return ChatResult(generations=[ChatGeneration(message=message)])
diff --git a/langchain/llms/anthropic.py b/langchain/llms/anthropic.py
index bc4cfd42032..24d9def1eb7 100644
--- a/langchain/llms/anthropic.py
+++ b/langchain/llms/anthropic.py
@@ -1,15 +1,100 @@
 """Wrapper around Anthropic APIs."""
 import re
-from typing import Any, Dict, Generator, List, Mapping, Optional
+from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
 
-from pydantic import Extra, root_validator
+from pydantic import BaseModel, Extra, root_validator
 
 from langchain.llms.base import LLM
 from langchain.utils import get_from_dict_or_env
 
 
-class Anthropic(LLM):
-    r"""Wrapper around Anthropic large language models.
+class _AnthropicCommon(BaseModel):
+    client: Any = None  #: :meta private:
+    model: str = "claude-latest"
+    """Model name to use."""
+
+    max_tokens_to_sample: int = 256
+    """Denotes the number of tokens to predict per generation."""
+
+    temperature: Optional[float] = None
+    """A non-negative float that tunes the degree of randomness in generation."""
+
+    top_k: Optional[int] = None
+    """Number of most likely tokens to consider at each step."""
+
+    top_p: Optional[float] = None
+    """Total probability mass of tokens to consider at each step."""
+
+    streaming: bool = False
+    """Whether to stream the results."""
+
+    anthropic_api_key: Optional[str] = None
+
+    HUMAN_PROMPT: Optional[str] = None
+    AI_PROMPT: Optional[str] = None
+    count_tokens: Optional[Callable[[str], int]] = None
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exist in environment."""
+        anthropic_api_key = get_from_dict_or_env(
+            values, "anthropic_api_key", "ANTHROPIC_API_KEY"
+        )
+        try:
+            import anthropic
+
+            values["client"] = anthropic.Client(anthropic_api_key)
+            values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
+            values["AI_PROMPT"] = anthropic.AI_PROMPT
+            values["count_tokens"] = anthropic.count_tokens
+        except ImportError:
+            raise ValueError(
+                "Could not import anthropic python package. "
+                "Please install it with `pip install anthropic`."
+            )
+        return values
+
+    @property
+    def _default_params(self) -> Mapping[str, Any]:
+        """Get the default parameters for calling Anthropic API."""
+        d = {
+            "max_tokens_to_sample": self.max_tokens_to_sample,
+            "model": self.model,
+        }
+        if self.temperature is not None:
+            d["temperature"] = self.temperature
+        if self.top_k is not None:
+            d["top_k"] = self.top_k
+        if self.top_p is not None:
+            d["top_p"] = self.top_p
+        return d
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {**self._default_params}
+
+    def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
+        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
+            raise NameError("Please ensure the anthropic package is loaded")
+
+        if stop is None:
+            stop = []
+
+        # Never want model to invent new turns of Human / Assistant dialog.
+        stop.extend([self.HUMAN_PROMPT])
+
+        return stop
+
+    def get_num_tokens(self, text: str) -> int:
+        """Calculate number of tokens."""
+        if not self.count_tokens:
+            raise NameError("Please ensure the anthropic package is loaded")
+        return self.count_tokens(text)
+
+
+class Anthropic(LLM, _AnthropicCommon):
+    r"""Wrapper around Anthropic's large language models.
 
     To use, you should have the ``anthropic`` python package installed, and the
     environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
     it as a named parameter to the constructor.
@@ -32,73 +117,15 @@ class Anthropic(LLM):
         response = model(prompt)
     """
 
-    client: Any  #: :meta private:
-    model: str = "claude-v1"
-    """Model name to use."""
-
-    max_tokens_to_sample: int = 256
-    """Denotes the number of tokens to predict per generation."""
-
-    temperature: float = 1.0
-    """A non-negative float that tunes the degree of randomness in generation."""
-
-    top_k: int = 0
-    """Number of most likely tokens to consider at each step."""
-
-    top_p: float = 1
-    """Total probability mass of tokens to consider at each step."""
-
-    streaming: bool = False
-    """Whether to stream the results."""
-
-    anthropic_api_key: Optional[str] = None
-
-    HUMAN_PROMPT: Optional[str] = None
-    AI_PROMPT: Optional[str] = None
-
     class Config:
         """Configuration for this pydantic object."""
 
         extra = Extra.forbid
 
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key and python package exists in environment."""
-        anthropic_api_key = get_from_dict_or_env(
-            values, "anthropic_api_key", "ANTHROPIC_API_KEY"
-        )
-        try:
-            import anthropic
-
-            values["client"] = anthropic.Client(anthropic_api_key)
-            values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
-            values["AI_PROMPT"] = anthropic.AI_PROMPT
-        except ImportError:
-            raise ValueError(
-                "Could not import anthropic python package. "
-                "Please install it with `pip install anthropic`."
-            )
-        return values
-
-    @property
-    def _default_params(self) -> Mapping[str, Any]:
-        """Get the default parameters for calling Anthropic API."""
-        return {
-            "max_tokens_to_sample": self.max_tokens_to_sample,
-            "temperature": self.temperature,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-        }
-
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        """Get the identifying parameters."""
-        return {**{"model": self.model}, **self._default_params}
-
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "anthropic"
+        return "anthropic-llm"
 
     def _wrap_prompt(self, prompt: str) -> str:
         if not self.HUMAN_PROMPT or not self.AI_PROMPT:
@@ -115,18 +142,6 @@ class Anthropic(LLM):
         # As a last resort, wrap the prompt ourselves to emulate instruct-style.
         return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
 
-    def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
-        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
-            raise NameError("Please ensure the anthropic package is loaded")
-
-        if stop is None:
-            stop = []
-
-        # Never want model to invent new turns of Human / Assistant dialog.
-        stop.extend([self.HUMAN_PROMPT, self.AI_PROMPT])
-
-        return stop
-
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         r"""Call out to Anthropic's completion endpoint.
@@ -148,10 +163,8 @@ class Anthropic(LLM):
         stop = self._get_anthropic_stop(stop)
         if self.streaming:
             stream_resp = self.client.completion_stream(
-                model=self.model,
                 prompt=self._wrap_prompt(prompt),
                 stop_sequences=stop,
-                stream=True,
                 **self._default_params,
             )
             current_completion = ""
@@ -163,7 +176,6 @@ class Anthropic(LLM):
                 )
             return current_completion
         response = self.client.completion(
-            model=self.model,
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
             **self._default_params,
@@ -175,10 +187,8 @@ class Anthropic(LLM):
         stop = self._get_anthropic_stop(stop)
         if self.streaming:
             stream_resp = await self.client.acompletion_stream(
-                model=self.model,
                 prompt=self._wrap_prompt(prompt),
                 stop_sequences=stop,
-                stream=True,
                 **self._default_params,
             )
             current_completion = ""
@@ -195,7 +205,6 @@ class Anthropic(LLM):
                 )
             return current_completion
         response = await self.client.acompletion(
-            model=self.model,
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
             **self._default_params,
@@ -227,7 +236,6 @@ class Anthropic(LLM):
         """
         stop = self._get_anthropic_stop(stop)
         return self.client.completion_stream(
-            model=self.model,
            prompt=self._wrap_prompt(prompt),
            stop_sequences=stop,
            **self._default_params,
diff --git a/tests/integration_tests/chat_models/test_anthropic.py b/tests/integration_tests/chat_models/test_anthropic.py
new file mode 100644
index 00000000000..f04b30e2514
--- /dev/null
+++ b/tests/integration_tests/chat_models/test_anthropic.py
@@ -0,0 +1,81 @@
+"""Test Anthropic API wrapper."""
+from typing import List
+
+import pytest
+
+from langchain.callbacks.base import CallbackManager
+from langchain.chat_models.anthropic import ChatAnthropic
+from langchain.schema import (
+    AIMessage,
+    BaseMessage,
+    ChatGeneration,
+    HumanMessage,
+    LLMResult,
+)
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+
+
+def test_anthropic_call() -> None:
+    """Test valid call to anthropic."""
+    chat = ChatAnthropic(model="bare-nano-0")
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+def test_anthropic_streaming() -> None:
+    """Test streaming tokens from anthropic."""
+    chat = ChatAnthropic(model="bare-nano-0", streaming=True)
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
+def test_anthropic_streaming_callback() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    chat = ChatAnthropic(
+        streaming=True,
+        callback_manager=callback_manager,
+        verbose=True,
+    )
+    message = HumanMessage(content="Write me a sentence with 100 words.")
+    chat([message])
+    assert callback_handler.llm_streams > 1
+
+
+@pytest.mark.asyncio
+async def test_anthropic_async_streaming_callback() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    chat = ChatAnthropic(
+        streaming=True,
+        callback_manager=callback_manager,
+        verbose=True,
+    )
+    chat_messages: List[BaseMessage] = [
+        HumanMessage(content="How many toes do dogs have?")
+    ]
+    result: LLMResult = await chat.agenerate([chat_messages])
+    assert callback_handler.llm_streams > 1
+    assert isinstance(result, LLMResult)
+    for response in result.generations[0]:
+        assert isinstance(response, ChatGeneration)
+        assert isinstance(response.text, str)
+        assert response.text == response.message.content
+
+
+def test_formatting() -> None:
+    chat = ChatAnthropic()
+
+    chat_messages: List[BaseMessage] = [HumanMessage(content="Hello")]
+    result = chat._convert_messages_to_prompt(chat_messages)
+    assert result == "\n\nHuman: Hello\n\nAssistant:"
+
+    chat_messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
+    result = chat._convert_messages_to_prompt(chat_messages)
+    assert result == "\n\nHuman: Hello\n\nAssistant: Answer:"
diff --git a/tests/integration_tests/llms/test_anthropic.py b/tests/integration_tests/llms/test_anthropic.py
index eaa509bf644..8c7717cfc7d 100644
--- a/tests/integration_tests/llms/test_anthropic.py
+++ b/tests/integration_tests/llms/test_anthropic.py
@@ -32,7 +32,6 @@ def test_anthropic_streaming_callback() -> None:
     callback_handler = FakeCallbackHandler()
     callback_manager = CallbackManager([callback_handler])
     llm = Anthropic(
-        model="claude-v1",
         streaming=True,
         callback_manager=callback_manager,
         verbose=True,
@@ -55,7 +54,6 @@ async def test_anthropic_async_streaming_callback() -> None:
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])
    llm = Anthropic(
-        model="claude-v1",
        streaming=True,
        callback_manager=callback_manager,
        verbose=True,
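
Below is a minimal usage sketch of the ChatAnthropic class this patch adds. It is illustrative rather than part of the diff: it assumes `pip install langchain anthropic` at the versions this PR targets and ANTHROPIC_API_KEY set in the environment, and it mirrors the calls the integration tests above make.

    from langchain.chat_models.anthropic import ChatAnthropic
    from langchain.schema import AIMessage, HumanMessage

    # _AnthropicCommon's root validator picks up ANTHROPIC_API_KEY from the
    # environment and builds the anthropic.Client.
    chat = ChatAnthropic()

    # The message list is flattened to "\n\nHuman: Hello\n\nAssistant:" by
    # _convert_messages_to_prompt (the exact behavior test_formatting checks),
    # and the completion comes back as a single AIMessage.
    response = chat([HumanMessage(content="Hello")])
    assert isinstance(response, AIMessage)
    print(response.content)

    # With streaming=True, the same call also fires on_llm_new_token once per
    # streamed delta, which is what the streaming callback tests exercise.
    streaming_chat = ChatAnthropic(streaming=True)
    streaming_chat([HumanMessage(content="Hello")])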