mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-27 08:58:48 +00:00)
Accept message-like things in Chat models, LLMs and MessagesPlaceholder (#16418)
parent 570b4f8e66
commit 52ccae3fb1
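In short: chat models, LLMs, and MessagesPlaceholder now accept "message-like" inputs (bare strings, (role, content) 2-tuples, and OpenAI-style role/content dicts) anywhere a list of BaseMessage was previously required, coercing them through the new convert_to_messages() helper defined later in this diff. A minimal sketch of the resulting behavior, using only names introduced by this commit:

# Strings, 2-tuples, dicts, and BaseMessage instances may now be mixed freely.
from langchain_core.messages import convert_to_messages

messages = convert_to_messages(
    [
        ("system", "You are a helpful assistant."),  # (role, content) tuple
        {"role": "user", "content": "Hello!"},       # OpenAI-style dict
        "How are you?",                              # bare string -> HumanMessage
    ]
)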
@@ -16,7 +16,12 @@ from typing import (
 from typing_extensions import TypeAlias
 
 from langchain_core._api import deprecated
-from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
+from langchain_core.messages import (
+    AnyMessage,
+    BaseMessage,
+    MessageLikeRepresentation,
+    get_buffer_string,
+)
 from langchain_core.prompt_values import PromptValue
 from langchain_core.runnables import Runnable, RunnableSerializable
 from langchain_core.utils import get_pydantic_field_names
@@ -49,7 +54,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
     return tokenizer.encode(text)
 
 
-LanguageModelInput = Union[PromptValue, str, Sequence[BaseMessage]]
+LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
 LanguageModelOutput = Union[BaseMessage, str]
 LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
 LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
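The widening of LanguageModelInput above is what lets every Runnable built on a language model accept the new input shapes without further type changes. A sketch of values that now satisfy the alias (assuming LanguageModelInput is re-exported from langchain_core.language_models, as it is in recent versions):

from typing import List

from langchain_core.language_models import LanguageModelInput

inputs: List[LanguageModelInput] = [
    "plain string",                       # str
    [("human", "hi")],                    # sequence of (role, content) tuples
    [{"role": "user", "content": "hi"}],  # sequence of OpenAI-style dicts
]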
@@ -34,6 +34,7 @@ from langchain_core.messages import (
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
+    convert_to_messages,
     message_chunk_to_message,
 )
 from langchain_core.outputs import (
@@ -144,7 +145,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         elif isinstance(input, str):
             return StringPromptValue(text=input)
         elif isinstance(input, Sequence):
-            return ChatPromptValue(messages=input)
+            return ChatPromptValue(messages=convert_to_messages(input))
         else:
             raise ValueError(
                 f"Invalid input type {type(input)}. "
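Because BaseChatModel's input conversion now routes sequences through convert_to_messages(), every subclass inherits the coercion for free. A runnable sketch (EchoChatModel is our own stand-in mirroring the test-only ParrotFakeChatModel added later in this diff; it is not part of the library):

from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult


class EchoChatModel(BaseChatModel):
    """Echoes the last input message, exposing what the base class coerced it into."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        return ChatResult(generations=[ChatGeneration(message=messages[-1])])

    @property
    def _llm_type(self) -> str:
        return "echo-chat-model"


model = EchoChatModel()
assert model.invoke("hello") == HumanMessage(content="hello")
assert model.invoke([("ai", "blah")]) == AIMessage(content="blah")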
@@ -48,7 +48,12 @@ from langchain_core.callbacks import (
 from langchain_core.globals import get_llm_cache
 from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
 from langchain_core.load import dumpd
-from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    convert_to_messages,
+    get_buffer_string,
+)
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.pydantic_v1 import Field, root_validator, validator
@@ -210,7 +215,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         elif isinstance(input, str):
             return StringPromptValue(text=input)
         elif isinstance(input, Sequence):
-            return ChatPromptValue(messages=input)
+            return ChatPromptValue(messages=convert_to_messages(input))
         else:
             raise ValueError(
                 f"Invalid input type {type(input)}. "
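The same coercion applies to completion-style LLMs: BaseLLM builds a ChatPromptValue from the coerced messages, which is then flattened to a single prompt string via get_buffer_string() before generation. A sketch, assuming FakeListLLM is importable from langchain_core.language_models (true of current versions; any BaseLLM subclass behaves the same):

from langchain_core.language_models import FakeListLLM

llm = FakeListLLM(responses=["ok"])

# Each of these is accepted; message-like inputs are coerced to messages,
# then rendered to one prompt string before the LLM sees them.
llm.invoke("hello")
llm.invoke([("human", "hello")])
llm.invoke([{"role": "user", "content": "hello"}])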
@@ -1,4 +1,4 @@
-from typing import List, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 from langchain_core.messages.ai import AIMessage, AIMessageChunk
 from langchain_core.messages.base import (
@@ -117,6 +117,110 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
     )
 
 
+MessageLikeRepresentation = Union[BaseMessage, Tuple[str, str], str, Dict[str, Any]]
+
+
+def _create_message_from_message_type(
+    message_type: str,
+    content: str,
+    name: Optional[str] = None,
+    tool_call_id: Optional[str] = None,
+    **additional_kwargs: Any,
+) -> BaseMessage:
+    """Create a message from a message type and content string.
+
+    Args:
+        message_type: str the type of the message (e.g., "human", "ai", etc.)
+        content: str the content string.
+
+    Returns:
+        a message of the appropriate type.
+    """
+    kwargs: Dict[str, Any] = {}
+    if name is not None:
+        kwargs["name"] = name
+    if tool_call_id is not None:
+        kwargs["tool_call_id"] = tool_call_id
+    if additional_kwargs:
+        kwargs["additional_kwargs"] = additional_kwargs  # type: ignore[assignment]
+    if message_type in ("human", "user"):
+        message: BaseMessage = HumanMessage(content=content, **kwargs)
+    elif message_type in ("ai", "assistant"):
+        message = AIMessage(content=content, **kwargs)
+    elif message_type == "system":
+        message = SystemMessage(content=content, **kwargs)
+    elif message_type == "function":
+        message = FunctionMessage(content=content, **kwargs)
+    elif message_type == "tool":
+        message = ToolMessage(content=content, **kwargs)
+    else:
+        raise ValueError(
+            f"Unexpected message type: {message_type}. Use one of 'human',"
+            f" 'user', 'ai', 'assistant', or 'system'."
+        )
+    return message
+
+
+def _convert_to_message(
+    message: MessageLikeRepresentation,
+) -> BaseMessage:
+    """Instantiate a message from a variety of message formats.
+
+    The message format can be one of the following:
+
+    - BaseMessagePromptTemplate
+    - BaseMessage
+    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
+    - dict: a message dict with role and content keys
+    - string: shorthand for ("human", template); e.g., "{user_input}"
+
+    Args:
+        message: a representation of a message in one of the supported formats
+
+    Returns:
+        an instance of a message or a message template
+    """
+    if isinstance(message, BaseMessage):
+        _message = message
+    elif isinstance(message, str):
+        _message = _create_message_from_message_type("human", message)
+    elif isinstance(message, tuple):
+        if len(message) != 2:
+            raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
+        message_type_str, template = message
+        _message = _create_message_from_message_type(message_type_str, template)
+    elif isinstance(message, dict):
+        msg_kwargs = message.copy()
+        try:
+            msg_type = msg_kwargs.pop("role")
+            msg_content = msg_kwargs.pop("content")
+        except KeyError:
+            raise ValueError(
+                f"Message dict must contain 'role' and 'content' keys, got {message}"
+            )
+        _message = _create_message_from_message_type(
+            msg_type, msg_content, **msg_kwargs
+        )
+    else:
+        raise NotImplementedError(f"Unsupported message type: {type(message)}")
+
+    return _message
+
+
+def convert_to_messages(
+    messages: Sequence[MessageLikeRepresentation],
+) -> List[BaseMessage]:
+    """Convert a sequence of messages to a list of messages.
+
+    Args:
+        messages: Sequence of messages to convert.
+
+    Returns:
+        List of messages (BaseMessages).
+    """
+    return [_convert_to_message(m) for m in messages]
+
+
 __all__ = [
     "AIMessage",
     "AIMessageChunk",
@@ -133,6 +237,7 @@ __all__ = [
     "SystemMessageChunk",
     "ToolMessage",
     "ToolMessageChunk",
+    "convert_to_messages",
    "get_buffer_string",
     "message_chunk_to_message",
     "messages_from_dict",
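One detail worth noting in _create_message_from_message_type above: dict keys other than role, content, name, and tool_call_id are collected into additional_kwargs rather than dropped. A small sketch of that routing, using only names from this diff:

from langchain_core.messages import AIMessage, convert_to_messages

[msg] = convert_to_messages(
    [
        {
            "role": "assistant",
            "content": "Hi!",
            "function_call": {"name": "greet", "arguments": '{"name": "Jane"}'},
        }
    ]
)
assert isinstance(msg, AIMessage)
assert msg.additional_kwargs["function_call"]["name"] == "greet"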
@@ -27,6 +27,7 @@ from langchain_core.messages import (
     ChatMessage,
     HumanMessage,
     SystemMessage,
+    convert_to_messages,
 )
 from langchain_core.messages.base import get_msg_title_repr
 from langchain_core.prompt_values import ChatPromptValue, PromptValue
@@ -126,7 +127,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
                 f"variable {self.variable_name} should be a list of base messages, "
                 f"got {value}"
             )
-        for v in value:
+        for v in convert_to_messages(value):
             if not isinstance(v, BaseMessage):
                 raise ValueError(
                     f"variable {self.variable_name} should be a list of base messages,"
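With the loop above iterating over convert_to_messages(value), a placeholder variable can now be filled with message-like values directly. A runnable sketch (MessagesPlaceholder is exported from langchain_core.prompts):

from langchain_core.prompts import MessagesPlaceholder

placeholder = MessagesPlaceholder("history")
# Tuples and bare strings are coerced before the BaseMessage check runs.
print(placeholder.format_messages(
    history=[("system", "You are an AI assistant."), "Hello!"]
))
# -> [SystemMessage(content='You are an AI assistant.'), HumanMessage(content='Hello!')]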
@@ -301,3 +301,24 @@ class GenericFakeChatModel(BaseChatModel):
     @property
     def _llm_type(self) -> str:
         return "generic-fake-chat-model"
+
+
+class ParrotFakeChatModel(BaseChatModel):
+    """A generic fake chat model that can be used to test the chat model interface.
+
+    * Chat model should be usable in both sync and async tests
+    """
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Top Level call"""
+        return ChatResult(generations=[ChatGeneration(message=messages[-1])])
+
+    @property
+    def _llm_type(self) -> str:
+        return "parrot-fake-chat-model"
@@ -5,8 +5,9 @@ from uuid import UUID
 
 from langchain_core.callbacks.base import AsyncCallbackHandler
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.messages.human import HumanMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
-from tests.unit_tests.fake.chat_model import GenericFakeChatModel
+from tests.unit_tests.fake.chat_model import GenericFakeChatModel, ParrotFakeChatModel
 
 
 def test_generic_fake_chat_model_invoke() -> None:
@@ -182,3 +183,11 @@ async def test_callback_handlers() -> None:
         AIMessageChunk(content="goodbye"),
     ]
     assert tokens == ["hello", " ", "goodbye"]
+
+
+def test_chat_model_inputs() -> None:
+    fake = ParrotFakeChatModel()
+
+    assert fake.invoke("hello") == HumanMessage(content="hello")
+    assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah")
+    assert fake.invoke([AIMessage(content="blah")]) == AIMessage(content="blah")
@@ -16,6 +16,7 @@ EXPECTED_ALL = [
     "SystemMessageChunk",
     "ToolMessage",
     "ToolMessageChunk",
+    "convert_to_messages",
     "get_buffer_string",
     "message_chunk_to_message",
     "messages_from_dict",
@@ -369,3 +369,9 @@ def test_messages_placeholder() -> None:
         prompt.format_messages()
     prompt = MessagesPlaceholder("history", optional=True)
     assert prompt.format_messages() == []
+    assert prompt.format_messages(
+        history=[("system", "You are an AI assistant."), "Hello!"]
+    ) == [
+        SystemMessage(content="You are an AI assistant."),
+        HumanMessage(content="Hello!"),
+    ]
@@ -14,6 +14,7 @@ from langchain_core.messages import (
     HumanMessageChunk,
     SystemMessage,
     ToolMessage,
+    convert_to_messages,
     get_buffer_string,
     message_chunk_to_message,
     messages_from_dict,
@@ -428,3 +429,54 @@ def test_tool_calls_merge() -> None:
             ]
         },
     )
+
+
+def test_convert_to_messages() -> None:
+    # dicts
+    assert convert_to_messages(
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+            {"role": "ai", "content": "Hi!"},
+            {"role": "human", "content": "Hello!", "name": "Jane"},
+            {
+                "role": "assistant",
+                "content": "Hi!",
+                "name": "JaneBot",
+                "function_call": {"name": "greet", "arguments": '{"name": "Jane"}'},
+            },
+            {"role": "function", "name": "greet", "content": "Hi!"},
+            {"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"},
+        ]
+    ) == [
+        SystemMessage(content="You are a helpful assistant."),
+        HumanMessage(content="Hello!"),
+        AIMessage(content="Hi!"),
+        HumanMessage(content="Hello!", name="Jane"),
+        AIMessage(
+            content="Hi!",
+            name="JaneBot",
+            additional_kwargs={
+                "function_call": {"name": "greet", "arguments": '{"name": "Jane"}'}
+            },
+        ),
+        FunctionMessage(name="greet", content="Hi!"),
+        ToolMessage(tool_call_id="tool_id", content="Hi!"),
+    ]
+
+    # tuples
+    assert convert_to_messages(
+        [
+            ("system", "You are a helpful assistant."),
+            "hello!",
+            ("ai", "Hi!"),
+            ("human", "Hello!"),
+            ("assistant", "Hi!"),
+        ]
+    ) == [
+        SystemMessage(content="You are a helpful assistant."),
+        HumanMessage(content="hello!"),
+        AIMessage(content="Hi!"),
+        HumanMessage(content="Hello!"),
+        AIMessage(content="Hi!"),
+    ]