mirror of https://github.com/hwchase17/langchain.git
synced 2026-02-04 08:10:25 +00:00

Compare commits: eugene/why ... harrison/a
2 commits: e51f10ba1d, 6e37307cab
@@ -1,139 +0,0 @@
-from typing import Any, Dict, List, Optional
-
-from pydantic import Extra
-
-from langchain.chat_models.base import BaseChatModel
-from langchain.llms.anthropic import _AnthropicCommon
-from langchain.schema import (
-    AIMessage,
-    BaseMessage,
-    ChatGeneration,
-    ChatMessage,
-    ChatResult,
-    HumanMessage,
-    SystemMessage,
-)
-
-
-class ChatAnthropic(BaseChatModel, _AnthropicCommon):
-    r"""Wrapper around Anthropic's large language model.
-
-    To use, you should have the ``anthropic`` python package installed, and the
-    environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
-    it as a named parameter to the constructor.
-
-    Example:
-        .. code-block:: python
-
-            import anthropic
-            from langchain.llms import Anthropic
-            model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
-    """
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
-    @property
-    def _llm_type(self) -> str:
-        """Return type of chat model."""
-        return "anthropic-chat"
-
-    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
-        if isinstance(message, ChatMessage):
-            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
-        elif isinstance(message, HumanMessage):
-            message_text = f"{self.HUMAN_PROMPT} {message.content}"
-        elif isinstance(message, AIMessage):
-            message_text = f"{self.AI_PROMPT} {message.content}"
-        elif isinstance(message, SystemMessage):
-            message_text = f"{self.HUMAN_PROMPT} <admin>{message.content}</admin>"
-        else:
-            raise ValueError(f"Got unknown type {message}")
-        return message_text
-
-    def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
-        """Format a list of strings into a single string with necessary newlines.
-
-        Args:
-            messages (List[BaseMessage]): List of BaseMessage to combine.
-
-        Returns:
-            str: Combined string with necessary newlines.
-        """
-        return "".join(
-            self._convert_one_message_to_text(message) for message in messages
-        )
-
-    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
-        """Format a list of messages into a full prompt for the Anthropic model
-
-        Args:
-            messages (List[BaseMessage]): List of BaseMessage to combine.
-
-        Returns:
-            str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
-        """
-        if not self.AI_PROMPT:
-            raise NameError("Please ensure the anthropic package is loaded")
-
-        if not isinstance(messages[-1], AIMessage):
-            messages.append(AIMessage(content=""))
-        text = self._convert_messages_to_text(messages)
-        return (
-            text.rstrip()
-        )  # trim off the trailing ' ' that might come from the "Assistant: "
-
-    def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
-        if stop:
-            params["stop_sequences"] = stop
-
-        if self.streaming:
-            completion = ""
-            stream_resp = self.client.completion_stream(**params)
-            for data in stream_resp:
-                delta = data["completion"][len(completion) :]
-                completion = data["completion"]
-                self.callback_manager.on_llm_new_token(
-                    delta,
-                    verbose=self.verbose,
-                )
-        else:
-            response = self.client.completion(**params)
-            completion = response["completion"]
-        message = AIMessage(content=completion)
-        return ChatResult(generations=[ChatGeneration(message=message)])
-
-    async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
-        if stop:
-            params["stop_sequences"] = stop
-
-        if self.streaming:
-            completion = ""
-            stream_resp = await self.client.acompletion_stream(**params)
-            async for data in stream_resp:
-                delta = data["completion"][len(completion) :]
-                completion = data["completion"]
-                if self.callback_manager.is_async:
-                    await self.callback_manager.on_llm_new_token(
-                        delta,
-                        verbose=self.verbose,
-                    )
-                else:
-                    self.callback_manager.on_llm_new_token(
-                        delta,
-                        verbose=self.verbose,
-                    )
-        else:
-            response = await self.client.acompletion(**params)
-            completion = response["completion"]
-        message = AIMessage(content=completion)
-        return ChatResult(generations=[ChatGeneration(message=message)])
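The hunk above deletes the standalone `ChatAnthropic` chat-model wrapper; the hunks that follow fold its message-formatting logic into the plain `Anthropic` LLM. To make the removed behavior concrete, here is a minimal, dependency-free sketch of the prompt string it built. The `HUMAN_PROMPT`/`AI_PROMPT` values are assumptions about the constants the `anthropic` package provides, and `convert_messages_to_prompt` is an illustrative stand-in, not the original method.

```python
# Minimal sketch of the deleted prompt formatting, independent of langchain.
# The two constants mirror what the `anthropic` package is assumed to export.
HUMAN_PROMPT = "\n\nHuman:"
AI_PROMPT = "\n\nAssistant:"


def convert_messages_to_prompt(messages: list[tuple[str, str]]) -> str:
    """messages is a list of (role, content) pairs: 'human', 'ai', or 'system'."""
    parts = []
    for role, content in messages:
        if role == "human":
            parts.append(f"{HUMAN_PROMPT} {content}")
        elif role == "ai":
            parts.append(f"{AI_PROMPT} {content}")
        elif role == "system":
            # System messages are folded into a human turn wrapped in <admin> tags.
            parts.append(f"{HUMAN_PROMPT} <admin>{content}</admin>")
        else:
            raise ValueError(f"Got unknown role {role}")
    # Always end on an (empty) assistant turn so the model completes it.
    if not messages or messages[-1][0] != "ai":
        parts.append(f"{AI_PROMPT} ")
    return "".join(parts).rstrip()


print(convert_messages_to_prompt([("human", "Hello")]))
# -> "\n\nHuman: Hello\n\nAssistant:"
```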
@@ -5,6 +5,13 @@ from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
 from pydantic import BaseModel, Extra, root_validator

 from langchain.llms.base import LLM
+from langchain.schema import (
+    AIMessage,
+    BaseMessage,
+    ChatMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain.utils import get_from_dict_or_env

@@ -127,6 +134,51 @@ class Anthropic(LLM, _AnthropicCommon):
         """Return type of llm."""
         return "anthropic-llm"

+    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
+        if isinstance(message, ChatMessage):
+            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+        elif isinstance(message, HumanMessage):
+            message_text = f"{self.HUMAN_PROMPT} {message.content}"
+        elif isinstance(message, AIMessage):
+            message_text = f"{self.AI_PROMPT} {message.content}"
+        elif isinstance(message, SystemMessage):
+            message_text = f"{self.HUMAN_PROMPT} <admin>{message.content}</admin>"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_text
+
+    def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
+        """Format a list of strings into a single string with necessary newlines.
+
+        Args:
+            messages (List[BaseMessage]): List of BaseMessage to combine.
+
+        Returns:
+            str: Combined string with necessary newlines.
+        """
+        return "".join(
+            self._convert_one_message_to_text(message) for message in messages
+        )
+
+    def _convert_messages_to_string(self, messages: List[BaseMessage]) -> str:
+        """Format a list of messages into a full prompt for the Anthropic model
+
+        Args:
+            messages (List[BaseMessage]): List of BaseMessage to combine.
+
+        Returns:
+            str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
+        """
+        if not self.AI_PROMPT:
+            raise NameError("Please ensure the anthropic package is loaded")
+
+        if not isinstance(messages[-1], AIMessage):
+            messages.append(AIMessage(content=""))
+        text = self._convert_messages_to_text(messages)
+        return (
+            text.rstrip()
+        )  # trim off the trailing ' ' that might come from the "Assistant: "
+
     def _wrap_prompt(self, prompt: str) -> str:
         if not self.HUMAN_PROMPT or not self.AI_PROMPT:
             raise NameError("Please ensure the anthropic package is loaded")
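For concreteness, this is the kind of string the newly added `_convert_messages_to_string` should return. A hedged usage sketch — the model name and the `HUMAN_PROMPT`/`AI_PROMPT` values are assumptions, and the method is private, so this is illustration rather than supported API:

```python
# Hypothetical usage; assumes ANTHROPIC_API_KEY is set and that
# anthropic.HUMAN_PROMPT == "\n\nHuman:" and anthropic.AI_PROMPT == "\n\nAssistant:".
from langchain.llms.anthropic import Anthropic
from langchain.schema import HumanMessage, SystemMessage

llm = Anthropic(model="claude-v1")
prompt = llm._convert_messages_to_string(
    [SystemMessage(content="Be terse."), HumanMessage(content="Hi")]
)
# prompt == "\n\nHuman: <admin>Be terse.</admin>\n\nHuman: Hi\n\nAssistant:"
```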
@@ -10,7 +10,15 @@ from pydantic import Extra, Field, validator
 import langchain
 from langchain.callbacks import get_callback_manager
 from langchain.callbacks.base import BaseCallbackManager
-from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue
+from langchain.schema import (
+    BaseLanguageModel,
+    BaseMessage,
+    Generation,
+    LLMResult,
+    PromptType,
+    PromptValue,
+    get_buffer_string,
+)


 def _get_verbosity() -> bool:
@@ -100,16 +108,37 @@ class BaseLLM(BaseLanguageModel, ABC):
     ) -> LLMResult:
         """Run the LLM on the given prompts."""

+    def _convert_messages_to_string(self, messages: List[BaseMessage]) -> str:
+        return get_buffer_string(messages)
+
     def generate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
     ) -> LLMResult:
-        prompt_strings = [p.to_string() for p in prompts]
+        prompt_strings = []
+        for prompt in prompts:
+            if prompt.type == PromptType.string:
+                prompt_strings.append(prompt.to_string())
+            elif prompt.type == PromptType.messages:
+                prompt_strings.append(
+                    self._convert_messages_to_string(prompt.to_messages())
+                )
+            else:
+                raise ValueError(f"Unexpected prompt type: {prompt.type}")
         return self.generate(prompt_strings, stop=stop)

     async def agenerate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
     ) -> LLMResult:
-        prompt_strings = [p.to_string() for p in prompts]
+        prompt_strings = []
+        for prompt in prompts:
+            if prompt.type == PromptType.string:
+                prompt_strings.append(prompt.to_string())
+            elif prompt.type == PromptType.messages:
+                prompt_strings.append(
+                    self._convert_messages_to_string(prompt.to_messages())
+                )
+            else:
+                raise ValueError(f"Unexpected prompt type: {prompt.type}")
         return await self.agenerate(prompt_strings, stop=stop)

     def generate(
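The reworked `generate_prompt` dispatches on `prompt.type` instead of unconditionally calling `to_string()`: string prompts pass through, while message prompts are first flattened via `get_buffer_string` (assumed here to render each message as `Role: content`, one per line). A standalone sketch of that dispatch, with stand-in names for the real classes:

```python
# Standalone sketch of the dispatch added to BaseLLM.generate_prompt.
# PromptType is copied from this diff; buffer_string mimics the assumed
# output format of langchain.schema.get_buffer_string.
from enum import Enum


class PromptType(Enum):
    string = "string"
    messages = "messages"


def buffer_string(messages: list[tuple[str, str]]) -> str:
    # (role, content) pairs rendered one per line, e.g. "Human: hi".
    return "\n".join(f"{role}: {content}" for role, content in messages)


def to_prompt_string(prompt_type: PromptType, payload) -> str:
    if prompt_type == PromptType.string:
        return payload  # already a plain string
    elif prompt_type == PromptType.messages:
        return buffer_string(payload)
    raise ValueError(f"Unexpected prompt type: {prompt_type}")


print(to_prompt_string(PromptType.messages, [("Human", "hi"), ("AI", "hello")]))
# Human: hi
# AI: hello
```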
@@ -15,6 +15,7 @@ from langchain.schema import (
     BaseMessage,
     ChatMessage,
     HumanMessage,
+    PromptType,
     PromptValue,
     SystemMessage,
 )
@@ -108,6 +109,7 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):


 class ChatPromptValue(PromptValue):
+    type: PromptType = PromptType.messages
     messages: List[BaseMessage]

     def to_string(self) -> str:
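With the `type` discriminator in place, a `ChatPromptValue` advertises itself as a messages-style prompt, so the `generate_prompt` above routes it through `_convert_messages_to_string` rather than `to_string()`. A hedged usage sketch (the class and field names come from this diff; the construction details are assumptions about this branch):

```python
# Hypothetical usage on this branch.
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import HumanMessage, PromptType

value = ChatPromptValue(messages=[HumanMessage(content="Hello")])
assert value.type == PromptType.messages  # flattened, not to_string()'d
```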
@@ -2,6 +2,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import (
     Any,
     Dict,
@@ -171,7 +172,14 @@ class LLMResult(BaseModel):
     """For arbitrary LLM provider specific output."""


+class PromptType(Enum):
+    string = "string"
+    messages = "messages"
+
+
 class PromptValue(BaseModel, ABC):
+    type: PromptType = PromptType.string
+
     @abstractmethod
     def to_string(self) -> str:
         """Return prompt as string."""
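This enum is the pivot of the whole change: every `PromptValue` now carries a `type` tag that defaults to `string`, and subclasses such as `ChatPromptValue` override it. A self-contained sketch of the pattern — `PromptType` and `PromptValue` mirror the diff, while `StringPromptValue` is an illustrative subclass, not taken from the source:

```python
# Self-contained sketch of the tagged PromptValue pattern from this diff.
from abc import ABC, abstractmethod
from enum import Enum

from pydantic import BaseModel


class PromptType(Enum):
    string = "string"
    messages = "messages"


class PromptValue(BaseModel, ABC):
    type: PromptType = PromptType.string

    @abstractmethod
    def to_string(self) -> str:
        """Return prompt as string."""


class StringPromptValue(PromptValue):
    text: str

    def to_string(self) -> str:
        return self.text


p = StringPromptValue(text="Hello")
assert p.type == PromptType.string  # consumers branch on the tag, not isinstance
```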
@@ -4,7 +4,7 @@ from typing import List
 import pytest

 from langchain.callbacks.base import CallbackManager
-from langchain.chat_models.anthropic import ChatAnthropic
+from langchain.llms.anthropic import Anthropic
 from langchain.schema import (
     AIMessage,
     BaseMessage,
@@ -17,7 +17,7 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

 def test_anthropic_call() -> None:
     """Test valid call to anthropic."""
-    chat = ChatAnthropic(model="test")
+    chat = Anthropic(model="test")
     message = HumanMessage(content="Hello")
     response = chat([message])
     assert isinstance(response, AIMessage)
@@ -26,7 +26,7 @@ def test_anthropic_call() -> None:

 def test_anthropic_streaming() -> None:
     """Test streaming tokens from anthropic."""
-    chat = ChatAnthropic(model="test", streaming=True)
+    chat = Anthropic(model="test", streaming=True)
     message = HumanMessage(content="Hello")
     response = chat([message])
     assert isinstance(response, AIMessage)
@@ -37,7 +37,7 @@ def test_anthropic_streaming_callback() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandler()
     callback_manager = CallbackManager([callback_handler])
-    chat = ChatAnthropic(
+    chat = Anthropic(
         model="test",
         streaming=True,
         callback_manager=callback_manager,
@@ -53,7 +53,7 @@ async def test_anthropic_async_streaming_callback() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandler()
     callback_manager = CallbackManager([callback_handler])
-    chat = ChatAnthropic(
+    chat = Anthropic(
         model="test",
         streaming=True,
         callback_manager=callback_manager,
@@ -72,7 +72,7 @@ async def test_anthropic_async_streaming_callback() -> None:

 def test_formatting() -> None:
-    chat = ChatAnthropic()
+    chat = Anthropic()

     chat_messages: List[BaseMessage] = [HumanMessage(content="Hello")]
     result = chat._convert_messages_to_prompt(chat_messages)