Compare commits

...

2 Commits

Author SHA1 Message Date
Harrison Chase
e51f10ba1d wip: anthropic model 2023-04-20 17:25:13 -07:00
Harrison Chase
6e37307cab wip: anthropic formatting 2023-04-20 17:21:39 -07:00
6 changed files with 100 additions and 148 deletions

View File

@@ -1,139 +0,0 @@
from typing import Any, Dict, List, Optional
from pydantic import Extra
from langchain.chat_models.base import BaseChatModel
from langchain.llms.anthropic import _AnthropicCommon
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
ChatMessage,
ChatResult,
HumanMessage,
SystemMessage,
)
class ChatAnthropic(BaseChatModel, _AnthropicCommon):
    r"""Wrapper around Anthropic's large language model.

    To use, you should have the ``anthropic`` python package installed, and the
    environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            import anthropic
            from langchain.llms import Anthropic
            model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "anthropic-chat"

    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
        """Render a single message with Anthropic's role prefixes.

        Args:
            message: The message to render.

        Returns:
            str: The message content prefixed with the appropriate role tag.

        Raises:
            ValueError: If ``message`` is not a recognized message type.
        """
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = f"{self.HUMAN_PROMPT} {message.content}"
        elif isinstance(message, AIMessage):
            message_text = f"{self.AI_PROMPT} {message.content}"
        elif isinstance(message, SystemMessage):
            # Anthropic has no native system role here; embed it in a human
            # turn wrapped in an <admin> tag.
            message_text = f"{self.HUMAN_PROMPT} <admin>{message.content}</admin>"
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

    def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
        """Format a list of strings into a single string with necessary newlines.

        Args:
            messages (List[BaseMessage]): List of BaseMessage to combine.

        Returns:
            str: Combined string with necessary newlines.
        """
        return "".join(
            self._convert_one_message_to_text(message) for message in messages
        )

    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
        """Format a list of messages into a full prompt for the Anthropic model

        Args:
            messages (List[BaseMessage]): List of BaseMessage to combine.

        Returns:
            str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.

        Raises:
            NameError: If the anthropic package was not loaded (AI_PROMPT unset).
        """
        if not self.AI_PROMPT:
            raise NameError("Please ensure the anthropic package is loaded")
        # Ensure the prompt ends with an (empty) AI turn so the model
        # completes as the assistant.  Work on a copy so the caller's list is
        # never mutated as a side effect; also handles an empty message list
        # without raising IndexError.
        if not messages or not isinstance(messages[-1], AIMessage):
            messages = messages + [AIMessage(content="")]
        text = self._convert_messages_to_text(messages)
        return (
            text.rstrip()
        )  # trim off the trailing ' ' that might come from the "Assistant: "

    def _generate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        """Call out to Anthropic's completion endpoint synchronously.

        When ``self.streaming`` is set, tokens are streamed and each delta is
        forwarded to ``callback_manager.on_llm_new_token``.
        """
        prompt = self._convert_messages_to_prompt(messages)
        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
        if stop:
            params["stop_sequences"] = stop

        if self.streaming:
            completion = ""
            stream_resp = self.client.completion_stream(**params)
            for data in stream_resp:
                # The API returns the full completion so far; emit only the
                # newly generated suffix.
                delta = data["completion"][len(completion) :]
                completion = data["completion"]
                self.callback_manager.on_llm_new_token(
                    delta,
                    verbose=self.verbose,
                )
        else:
            response = self.client.completion(**params)
            completion = response["completion"]
        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        """Async counterpart of ``_generate`` using Anthropic's async client."""
        prompt = self._convert_messages_to_prompt(messages)
        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
        if stop:
            params["stop_sequences"] = stop

        if self.streaming:
            completion = ""
            stream_resp = await self.client.acompletion_stream(**params)
            async for data in stream_resp:
                delta = data["completion"][len(completion) :]
                completion = data["completion"]
                # Dispatch the token callback on the appropriate (a)sync path.
                if self.callback_manager.is_async:
                    await self.callback_manager.on_llm_new_token(
                        delta,
                        verbose=self.verbose,
                    )
                else:
                    self.callback_manager.on_llm_new_token(
                        delta,
                        verbose=self.verbose,
                    )
        else:
            response = await self.client.acompletion(**params)
            completion = response["completion"]
        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

View File

@@ -5,6 +5,13 @@ from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM
from langchain.schema import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain.utils import get_from_dict_or_env
@@ -127,6 +134,51 @@ class Anthropic(LLM, _AnthropicCommon):
"""Return type of llm."""
return "anthropic-llm"
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
    """Render one message as text with the Anthropic role prefix.

    Raises:
        ValueError: If the message is of an unknown type.
    """
    if isinstance(message, ChatMessage):
        return f"\n\n{message.role.capitalize()}: {message.content}"
    if isinstance(message, HumanMessage):
        return f"{self.HUMAN_PROMPT} {message.content}"
    if isinstance(message, AIMessage):
        return f"{self.AI_PROMPT} {message.content}"
    if isinstance(message, SystemMessage):
        return f"{self.HUMAN_PROMPT} <admin>{message.content}</admin>"
    raise ValueError(f"Got unknown type {message}")
def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
    """Format a list of strings into a single string with necessary newlines.

    Args:
        messages (List[BaseMessage]): List of BaseMessage to combine.

    Returns:
        str: Combined string with necessary newlines.
    """
    rendered = [self._convert_one_message_to_text(msg) for msg in messages]
    return "".join(rendered)
def _convert_messages_to_string(self, messages: List[BaseMessage]) -> str:
    """Format a list of messages into a full prompt for the Anthropic model

    Args:
        messages (List[BaseMessage]): List of BaseMessage to combine.

    Returns:
        str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.

    Raises:
        NameError: If the anthropic package was not loaded (AI_PROMPT unset).
    """
    if not self.AI_PROMPT:
        raise NameError("Please ensure the anthropic package is loaded")
    # Append the trailing empty AI turn on a copy so the caller's list is
    # never mutated; also avoids IndexError on an empty message list.
    if not messages or not isinstance(messages[-1], AIMessage):
        messages = messages + [AIMessage(content="")]
    text = self._convert_messages_to_text(messages)
    return (
        text.rstrip()
    )  # trim off the trailing ' ' that might come from the "Assistant: "
def _wrap_prompt(self, prompt: str) -> str:
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
raise NameError("Please ensure the anthropic package is loaded")

View File

@@ -10,7 +10,15 @@ from pydantic import Extra, Field, validator
import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue
from langchain.schema import (
BaseLanguageModel,
BaseMessage,
Generation,
LLMResult,
PromptType,
PromptValue,
get_buffer_string,
)
def _get_verbosity() -> bool:
@@ -100,16 +108,37 @@ class BaseLLM(BaseLanguageModel, ABC):
) -> LLMResult:
"""Run the LLM on the given prompts."""
def _convert_messages_to_string(self, messages: List[BaseMessage]) -> str:
    """Collapse chat messages into one prompt string via ``get_buffer_string``."""
    return get_buffer_string(messages)
def generate_prompt(
    self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
    """Render each prompt value to a string and run the LLM over all of them.

    String prompts are used as-is; message prompts are flattened into a
    single buffer string first.

    Args:
        prompts: The prompt values to run.
        stop: Optional stop sequences forwarded to ``generate``.

    Returns:
        LLMResult: The aggregated generations.

    Raises:
        ValueError: If a prompt reports an unknown ``PromptType``.
    """
    # NOTE: a leftover eager `[p.to_string() for p in prompts]` assignment
    # (dead code, immediately overwritten) was removed here.
    prompt_strings = []
    for prompt in prompts:
        if prompt.type == PromptType.string:
            prompt_strings.append(prompt.to_string())
        elif prompt.type == PromptType.messages:
            prompt_strings.append(
                self._convert_messages_to_string(prompt.to_messages())
            )
        else:
            raise ValueError(f"Unexpected prompt type: {prompt.type}")
    return self.generate(prompt_strings, stop=stop)
async def agenerate_prompt(
    self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
    """Async counterpart of ``generate_prompt``.

    Args:
        prompts: The prompt values to run.
        stop: Optional stop sequences forwarded to ``agenerate``.

    Returns:
        LLMResult: The aggregated generations.

    Raises:
        ValueError: If a prompt reports an unknown ``PromptType``.
    """
    # NOTE: a leftover eager `[p.to_string() for p in prompts]` assignment
    # (dead code, immediately overwritten) was removed here.
    prompt_strings = []
    for prompt in prompts:
        if prompt.type == PromptType.string:
            prompt_strings.append(prompt.to_string())
        elif prompt.type == PromptType.messages:
            prompt_strings.append(
                self._convert_messages_to_string(prompt.to_messages())
            )
        else:
            raise ValueError(f"Unexpected prompt type: {prompt.type}")
    return await self.agenerate(prompt_strings, stop=stop)
def generate(

View File

@@ -15,6 +15,7 @@ from langchain.schema import (
BaseMessage,
ChatMessage,
HumanMessage,
PromptType,
PromptValue,
SystemMessage,
)
@@ -108,6 +109,7 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
class ChatPromptValue(PromptValue):
type: PromptType = PromptType.messages
messages: List[BaseMessage]
def to_string(self) -> str:

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Any,
Dict,
@@ -171,7 +172,14 @@ class LLMResult(BaseModel):
"""For arbitrary LLM provider specific output."""
class PromptType(Enum):
    """Discriminator for how a ``PromptValue`` should be rendered."""

    # Plain single-string prompt (completion-style models).
    string = "string"
    # List-of-messages prompt (chat-style models).
    messages = "messages"
class PromptValue(BaseModel, ABC):
    """Abstract container for model input.

    Subclasses set ``type`` to advertise how they should be rendered;
    chat prompt values override the default with ``PromptType.messages``.
    """

    # Defaults to a plain-string prompt.
    type: PromptType = PromptType.string

    @abstractmethod
    def to_string(self) -> str:
        """Return prompt as string."""

View File

@@ -4,7 +4,7 @@ from typing import List
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.llms.anthropic import Anthropic
from langchain.schema import (
AIMessage,
BaseMessage,
@@ -17,7 +17,7 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_anthropic_call() -> None:
    """Test valid call to anthropic."""
    # Removed a dead `chat = ChatAnthropic(model="test")` assignment that was
    # immediately overwritten (merge/diff residue).
    chat = Anthropic(model="test")
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert isinstance(response, AIMessage)
@@ -26,7 +26,7 @@ def test_anthropic_call() -> None:
def test_anthropic_streaming() -> None:
    """Test streaming tokens from anthropic."""
    # Removed a dead `chat = ChatAnthropic(...)` assignment that was
    # immediately overwritten (merge/diff residue).
    chat = Anthropic(model="test", streaming=True)
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert isinstance(response, AIMessage)
@@ -37,7 +37,7 @@ def test_anthropic_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
chat = Anthropic(
model="test",
streaming=True,
callback_manager=callback_manager,
@@ -53,7 +53,7 @@ async def test_anthropic_async_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
chat = Anthropic(
model="test",
streaming=True,
callback_manager=callback_manager,
@@ -72,7 +72,7 @@ async def test_anthropic_async_streaming_callback() -> None:
def test_formatting() -> None:
chat = ChatAnthropic()
chat = Anthropic()
chat_messages: List[BaseMessage] = [HumanMessage(content="Hello")]
result = chat._convert_messages_to_prompt(chat_messages)