Accept message-like things in Chat models, LLMs and MessagesPlaceholder (#16418)
commit 52ccae3fb1
parent 570b4f8e66
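The net effect: everywhere a model accepted PromptValue | str | Sequence[BaseMessage], it now also takes (role, content) tuples, bare strings, and role/content dicts, normalized through the new convert_to_messages helper. A minimal sketch of the widened call surface, assuming `model` is bound to any BaseChatModel instance (the variable is a placeholder, not part of this change):

from langchain_core.messages import HumanMessage

model.invoke("hello")                                    # str -> HumanMessage
model.invoke([("system", "be terse"), ("human", "hi")])  # (role, content) tuples
model.invoke([{"role": "user", "content": "hi"}])        # role/content dicts
model.invoke([HumanMessage(content="hi")])               # BaseMessage, as before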
@@ -16,7 +16,12 @@ from typing import (
 from typing_extensions import TypeAlias
 
 from langchain_core._api import deprecated
-from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
+from langchain_core.messages import (
+    AnyMessage,
+    BaseMessage,
+    MessageLikeRepresentation,
+    get_buffer_string,
+)
 from langchain_core.prompt_values import PromptValue
 from langchain_core.runnables import Runnable, RunnableSerializable
 from langchain_core.utils import get_pydantic_field_names
@@ -49,7 +54,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
     return tokenizer.encode(text)
 
 
-LanguageModelInput = Union[PromptValue, str, Sequence[BaseMessage]]
+LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
 LanguageModelOutput = Union[BaseMessage, str]
 LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
 LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
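Because LanguageModelLike and downstream chains are typed against LanguageModelInput, the widening propagates automatically to anything composed with a model. A sketch, importing the alias from langchain_core.language_models.base as defined above:

from langchain_core.language_models.base import LanguageModelLike

def ask(model: LanguageModelLike) -> None:
    # Both inputs satisfy Sequence[MessageLikeRepresentation].
    model.invoke([("human", "hi")])
    model.invoke([{"role": "user", "content": "hi"}])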
@@ -34,6 +34,7 @@ from langchain_core.messages import (
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
+    convert_to_messages,
     message_chunk_to_message,
 )
 from langchain_core.outputs import (
@@ -144,7 +145,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         elif isinstance(input, str):
             return StringPromptValue(text=input)
         elif isinstance(input, Sequence):
-            return ChatPromptValue(messages=input)
+            return ChatPromptValue(messages=convert_to_messages(input))
         else:
             raise ValueError(
                 f"Invalid input type {type(input)}. "
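A sketch of what the one-line change above produces for a sequence input; both imports appear elsewhere in this diff:

from langchain_core.messages import convert_to_messages
from langchain_core.prompt_values import ChatPromptValue

# Equivalent to what _convert_input now builds internally for a list input.
pv = ChatPromptValue(messages=convert_to_messages([("human", "hi")]))
print(pv.to_messages())  # [HumanMessage(content='hi')]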
@@ -48,7 +48,12 @@ from langchain_core.callbacks import (
 from langchain_core.globals import get_llm_cache
 from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
 from langchain_core.load import dumpd
-from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    convert_to_messages,
+    get_buffer_string,
+)
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.pydantic_v1 import Field, root_validator, validator
@@ -210,7 +215,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         elif isinstance(input, str):
             return StringPromptValue(text=input)
         elif isinstance(input, Sequence):
-            return ChatPromptValue(messages=input)
+            return ChatPromptValue(messages=convert_to_messages(input))
         else:
             raise ValueError(
                 f"Invalid input type {type(input)}. "
@@ -1,4 +1,4 @@
-from typing import List, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 from langchain_core.messages.ai import AIMessage, AIMessageChunk
 from langchain_core.messages.base import (
@@ -117,6 +117,110 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
     )
 
 
+MessageLikeRepresentation = Union[BaseMessage, Tuple[str, str], str, Dict[str, Any]]
+
+
+def _create_message_from_message_type(
+    message_type: str,
+    content: str,
+    name: Optional[str] = None,
+    tool_call_id: Optional[str] = None,
+    **additional_kwargs: Any,
+) -> BaseMessage:
+    """Create a message from a message type and content string.
+
+    Args:
+        message_type: str the type of the message (e.g., "human", "ai", etc.)
+        content: str the content string.
+
+    Returns:
+        a message of the appropriate type.
+    """
+    kwargs: Dict[str, Any] = {}
+    if name is not None:
+        kwargs["name"] = name
+    if tool_call_id is not None:
+        kwargs["tool_call_id"] = tool_call_id
+    if additional_kwargs:
+        kwargs["additional_kwargs"] = additional_kwargs  # type: ignore[assignment]
+    if message_type in ("human", "user"):
+        message: BaseMessage = HumanMessage(content=content, **kwargs)
+    elif message_type in ("ai", "assistant"):
+        message = AIMessage(content=content, **kwargs)
+    elif message_type == "system":
+        message = SystemMessage(content=content, **kwargs)
+    elif message_type == "function":
+        message = FunctionMessage(content=content, **kwargs)
+    elif message_type == "tool":
+        message = ToolMessage(content=content, **kwargs)
+    else:
+        raise ValueError(
+            f"Unexpected message type: {message_type}. Use one of 'human',"
+            f" 'user', 'ai', 'assistant', or 'system'."
+        )
+    return message
+
+
+def _convert_to_message(
+    message: MessageLikeRepresentation,
+) -> BaseMessage:
+    """Instantiate a message from a variety of message formats.
+
+    The message format can be one of the following:
+
+    - BaseMessagePromptTemplate
+    - BaseMessage
+    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
+    - dict: a message dict with role and content keys
+    - string: shorthand for ("human", template); e.g., "{user_input}"
+
+    Args:
+        message: a representation of a message in one of the supported formats
+
+    Returns:
+        an instance of a message or a message template
+    """
+    if isinstance(message, BaseMessage):
+        _message = message
+    elif isinstance(message, str):
+        _message = _create_message_from_message_type("human", message)
+    elif isinstance(message, tuple):
+        if len(message) != 2:
+            raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
+        message_type_str, template = message
+        _message = _create_message_from_message_type(message_type_str, template)
+    elif isinstance(message, dict):
+        msg_kwargs = message.copy()
+        try:
+            msg_type = msg_kwargs.pop("role")
+            msg_content = msg_kwargs.pop("content")
+        except KeyError:
+            raise ValueError(
+                f"Message dict must contain 'role' and 'content' keys, got {message}"
+            )
+        _message = _create_message_from_message_type(
+            msg_type, msg_content, **msg_kwargs
+        )
+    else:
+        raise NotImplementedError(f"Unsupported message type: {type(message)}")
+
+    return _message
+
+
+def convert_to_messages(
+    messages: Sequence[MessageLikeRepresentation],
+) -> List[BaseMessage]:
+    """Convert a sequence of messages to a list of messages.
+
+    Args:
+        messages: Sequence of messages to convert.
+
+    Returns:
+        List of messages (BaseMessages).
+    """
+    return [_convert_to_message(m) for m in messages]
+
+
 __all__ = [
     "AIMessage",
     "AIMessageChunk",
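The failure modes follow directly from _convert_to_message above; a short sketch in the pytest style used by the tests later in this diff:

import pytest

from langchain_core.messages import convert_to_messages

with pytest.raises(ValueError):
    convert_to_messages([("system", "too", "long")])  # not a 2-tuple
with pytest.raises(ValueError):
    convert_to_messages([{"content": "no role key"}])  # missing 'role'
with pytest.raises(NotImplementedError):
    convert_to_messages([3.14])  # unsupported representation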
@@ -133,6 +237,7 @@ __all__ = [
     "SystemMessageChunk",
     "ToolMessage",
     "ToolMessageChunk",
+    "convert_to_messages",
     "get_buffer_string",
     "message_chunk_to_message",
     "messages_from_dict",
@@ -27,6 +27,7 @@ from langchain_core.messages import (
     ChatMessage,
     HumanMessage,
     SystemMessage,
+    convert_to_messages,
 )
 from langchain_core.messages.base import get_msg_title_repr
 from langchain_core.prompt_values import ChatPromptValue, PromptValue
@@ -126,7 +127,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
                 f"variable {self.variable_name} should be a list of base messages, "
                 f"got {value}"
             )
-        for v in value:
+        for v in convert_to_messages(value):
             if not isinstance(v, BaseMessage):
                 raise ValueError(
                     f"variable {self.variable_name} should be a list of base messages,"
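Since the validation loop now iterates over convert_to_messages(value), message-like entries pass where previously only BaseMessage instances did. A sketch, assuming the usual langchain_core.prompts re-export of MessagesPlaceholder:

from langchain_core.prompts import MessagesPlaceholder

placeholder = MessagesPlaceholder("history")
# Tuples and bare strings no longer trip the isinstance check.
placeholder.format_messages(
    history=[("system", "You are an AI assistant."), "Hello!"]
)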
@@ -301,3 +301,24 @@ class GenericFakeChatModel(BaseChatModel):
     @property
     def _llm_type(self) -> str:
         return "generic-fake-chat-model"
+
+
+class ParrotFakeChatModel(BaseChatModel):
+    """A generic fake chat model that can be used to test the chat model interface.
+
+    * Chat model should be usable in both sync and async tests
+    """
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Top Level call"""
+        return ChatResult(generations=[ChatGeneration(message=messages[-1])])
+
+    @property
+    def _llm_type(self) -> str:
+        return "parrot-fake-chat-model"
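A note on the fixture: because the parrot returns the last message it was given, the value returned by invoke() is exactly what _convert_input produced, which makes the normalization directly observable. A usage sketch (the dict input complements the tuple and string cases in the test added below):

from langchain_core.messages import HumanMessage

fake = ParrotFakeChatModel()
# The dict is converted before _generate ever sees it.
assert fake.invoke([{"role": "user", "content": "hi"}]) == HumanMessage(content="hi")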
@@ -5,8 +5,9 @@ from uuid import UUID
 
 from langchain_core.callbacks.base import AsyncCallbackHandler
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.messages.human import HumanMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
-from tests.unit_tests.fake.chat_model import GenericFakeChatModel
+from tests.unit_tests.fake.chat_model import GenericFakeChatModel, ParrotFakeChatModel
 
 
 def test_generic_fake_chat_model_invoke() -> None:
@@ -182,3 +183,11 @@ async def test_callback_handlers() -> None:
         AIMessageChunk(content="goodbye"),
     ]
     assert tokens == ["hello", " ", "goodbye"]
+
+
+def test_chat_model_inputs() -> None:
+    fake = ParrotFakeChatModel()
+
+    assert fake.invoke("hello") == HumanMessage(content="hello")
+    assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah")
+    assert fake.invoke([AIMessage(content="blah")]) == AIMessage(content="blah")
@@ -16,6 +16,7 @@ EXPECTED_ALL = [
     "SystemMessageChunk",
     "ToolMessage",
     "ToolMessageChunk",
+    "convert_to_messages",
     "get_buffer_string",
     "message_chunk_to_message",
     "messages_from_dict",
@@ -369,3 +369,9 @@ def test_messages_placeholder() -> None:
         prompt.format_messages()
     prompt = MessagesPlaceholder("history", optional=True)
     assert prompt.format_messages() == []
+    prompt.format_messages(
+        history=[("system", "You are an AI assistant."), "Hello!"]
+    ) == [
+        SystemMessage(content="You are an AI assistant."),
+        HumanMessage(content="Hello!"),
+    ]
@@ -14,6 +14,7 @@ from langchain_core.messages import (
     HumanMessageChunk,
     SystemMessage,
     ToolMessage,
+    convert_to_messages,
     get_buffer_string,
     message_chunk_to_message,
     messages_from_dict,
@@ -428,3 +429,54 @@ def test_tool_calls_merge() -> None:
             ]
         },
     )
+
+
+def test_convert_to_messages() -> None:
+    # dicts
+    assert convert_to_messages(
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+            {"role": "ai", "content": "Hi!"},
+            {"role": "human", "content": "Hello!", "name": "Jane"},
+            {
+                "role": "assistant",
+                "content": "Hi!",
+                "name": "JaneBot",
+                "function_call": {"name": "greet", "arguments": '{"name": "Jane"}'},
+            },
+            {"role": "function", "name": "greet", "content": "Hi!"},
+            {"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"},
+        ]
+    ) == [
+        SystemMessage(content="You are a helpful assistant."),
+        HumanMessage(content="Hello!"),
+        AIMessage(content="Hi!"),
+        HumanMessage(content="Hello!", name="Jane"),
+        AIMessage(
+            content="Hi!",
+            name="JaneBot",
+            additional_kwargs={
+                "function_call": {"name": "greet", "arguments": '{"name": "Jane"}'}
+            },
+        ),
+        FunctionMessage(name="greet", content="Hi!"),
+        ToolMessage(tool_call_id="tool_id", content="Hi!"),
+    ]
+
+    # tuples
+    assert convert_to_messages(
+        [
+            ("system", "You are a helpful assistant."),
+            "hello!",
+            ("ai", "Hi!"),
+            ("human", "Hello!"),
+            ("assistant", "Hi!"),
+        ]
+    ) == [
+        SystemMessage(content="You are a helpful assistant."),
+        HumanMessage(content="hello!"),
+        AIMessage(content="Hi!"),
+        HumanMessage(content="Hello!"),
+        AIMessage(content="Hi!"),
+    ]