core[patch]: Fixes for convert_messages (#21207)

- support two-element sequences of any sequence type, not only tuples (e.g.
json.loads never produces tuples, only lists)
- support "type" as an alias for the "role" key in dict-form messages
- if id is passed in dict form, use it
- if tool_calls are passed in dict form, use them

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Nuno Campos
2024-05-02 09:55:42 -07:00
committed by GitHub
parent df49404794
commit 663747b730
3 changed files with 49 additions and 28 deletions

View File

@@ -44,7 +44,7 @@ jobs:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
poetry-version: ${{ env.POETRY_VERSION }} poetry-version: ${{ env.POETRY_VERSION }}
working-directory: ${{ inputs.working-directory }} working-directory: ${{ inputs.working-directory }}
cache-key: lint-with-extras # cache-key: lint-with-extras
- name: Check Poetry File - name: Check Poetry File
shell: bash shell: bash
@@ -79,14 +79,14 @@ jobs:
run: | run: |
poetry run pip install -e "$LANGCHAIN_LOCATION" poetry run pip install -e "$LANGCHAIN_LOCATION"
- name: Get .mypy_cache to speed up mypy # - name: Get .mypy_cache to speed up mypy
uses: actions/cache@v4 # uses: actions/cache@v4
env: # env:
SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" # SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2"
with: # with:
path: | # path: |
${{ env.WORKDIR }}/.mypy_cache # ${{ env.WORKDIR }}/.mypy_cache
key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} # key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }}
- name: Analysing the code with our lint - name: Analysing the code with our lint
@@ -113,14 +113,14 @@ jobs:
run: | run: |
poetry install --with test,test_integration poetry install --with test,test_integration
- name: Get .mypy_cache_test to speed up mypy # - name: Get .mypy_cache_test to speed up mypy
uses: actions/cache@v4 # uses: actions/cache@v4
env: # env:
SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" # SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2"
with: # with:
path: | # path: |
${{ env.WORKDIR }}/.mypy_cache_test # ${{ env.WORKDIR }}/.mypy_cache_test
key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} # key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }}
- name: Analysing the code with our lint - name: Analysing the code with our lint
working-directory: ${{ inputs.working-directory }} working-directory: ${{ inputs.working-directory }}

View File

@@ -130,7 +130,9 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
) )
MessageLikeRepresentation = Union[BaseMessage, Tuple[str, str], str, Dict[str, Any]] MessageLikeRepresentation = Union[
BaseMessage, List[str], Tuple[str, str], str, Dict[str, Any]
]
def _create_message_from_message_type( def _create_message_from_message_type(
@@ -138,6 +140,8 @@ def _create_message_from_message_type(
content: str, content: str,
name: Optional[str] = None, name: Optional[str] = None,
tool_call_id: Optional[str] = None, tool_call_id: Optional[str] = None,
tool_calls: Optional[List[Dict[str, Any]]] = None,
id: Optional[str] = None,
**additional_kwargs: Any, **additional_kwargs: Any,
) -> BaseMessage: ) -> BaseMessage:
"""Create a message from a message type and content string. """Create a message from a message type and content string.
@@ -156,6 +160,10 @@ def _create_message_from_message_type(
kwargs["tool_call_id"] = tool_call_id kwargs["tool_call_id"] = tool_call_id
if additional_kwargs: if additional_kwargs:
kwargs["additional_kwargs"] = additional_kwargs # type: ignore[assignment] kwargs["additional_kwargs"] = additional_kwargs # type: ignore[assignment]
if id is not None:
kwargs["id"] = id
if tool_calls is not None:
kwargs["tool_calls"] = tool_calls
if message_type in ("human", "user"): if message_type in ("human", "user"):
message: BaseMessage = HumanMessage(content=content, **kwargs) message: BaseMessage = HumanMessage(content=content, **kwargs)
elif message_type in ("ai", "assistant"): elif message_type in ("ai", "assistant"):
@@ -197,15 +205,17 @@ def _convert_to_message(
_message = message _message = message
elif isinstance(message, str): elif isinstance(message, str):
_message = _create_message_from_message_type("human", message) _message = _create_message_from_message_type("human", message)
elif isinstance(message, tuple): elif isinstance(message, Sequence) and len(message) == 2:
if len(message) != 2: # mypy doesn't realise this can't be a string given the previous branch
raise ValueError(f"Expected 2-tuple of (role, template), got {message}") message_type_str, template = message # type: ignore[misc]
message_type_str, template = message
_message = _create_message_from_message_type(message_type_str, template) _message = _create_message_from_message_type(message_type_str, template)
elif isinstance(message, dict): elif isinstance(message, dict):
msg_kwargs = message.copy() msg_kwargs = message.copy()
try: try:
msg_type = msg_kwargs.pop("role") try:
msg_type = msg_kwargs.pop("role")
except KeyError:
msg_type = msg_kwargs.pop("type")
msg_content = msg_kwargs.pop("content") msg_content = msg_kwargs.pop("content")
except KeyError: except KeyError:
raise ValueError( raise ValueError(

View File

@@ -545,8 +545,8 @@ def test_convert_to_messages() -> None:
[ [
{"role": "system", "content": "You are a helpful assistant."}, {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}, {"role": "user", "content": "Hello!"},
{"role": "ai", "content": "Hi!"}, {"role": "ai", "content": "Hi!", "id": "ai1"},
{"role": "human", "content": "Hello!", "name": "Jane"}, {"type": "human", "content": "Hello!", "name": "Jane", "id": "human1"},
{ {
"role": "assistant", "role": "assistant",
"content": "Hi!", "content": "Hi!",
@@ -554,13 +554,20 @@ def test_convert_to_messages() -> None:
"function_call": {"name": "greet", "arguments": '{"name": "Jane"}'}, "function_call": {"name": "greet", "arguments": '{"name": "Jane"}'},
}, },
{"role": "function", "name": "greet", "content": "Hi!"}, {"role": "function", "name": "greet", "content": "Hi!"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{"name": "greet", "args": {"name": "Jane"}, "id": "tool_id"}
],
},
{"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"}, {"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"},
] ]
) == [ ) == [
SystemMessage(content="You are a helpful assistant."), SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello!"), HumanMessage(content="Hello!"),
AIMessage(content="Hi!"), AIMessage(content="Hi!", id="ai1"),
HumanMessage(content="Hello!", name="Jane"), HumanMessage(content="Hello!", name="Jane", id="human1"),
AIMessage( AIMessage(
content="Hi!", content="Hi!",
name="JaneBot", name="JaneBot",
@@ -569,6 +576,10 @@ def test_convert_to_messages() -> None:
}, },
), ),
FunctionMessage(name="greet", content="Hi!"), FunctionMessage(name="greet", content="Hi!"),
AIMessage(
content="",
tool_calls=[ToolCall(name="greet", args={"name": "Jane"}, id="tool_id")],
),
ToolMessage(tool_call_id="tool_id", content="Hi!"), ToolMessage(tool_call_id="tool_id", content="Hi!"),
] ]
@@ -579,7 +590,7 @@ def test_convert_to_messages() -> None:
"hello!", "hello!",
("ai", "Hi!"), ("ai", "Hi!"),
("human", "Hello!"), ("human", "Hello!"),
("assistant", "Hi!"), ["assistant", "Hi!"],
] ]
) == [ ) == [
SystemMessage(content="You are a helpful assistant."), SystemMessage(content="You are a helpful assistant."),