From 47cfbe7522ab449d930212ec36600ab613954d68 Mon Sep 17 00:00:00 2001 From: Luca Dorigo Date: Fri, 22 Mar 2024 22:33:50 +0100 Subject: [PATCH] openai[patch]: [URGENT REGRESSION FIX] Don't fail if tool message already doesn't contain name (#19435) - [ ] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** a description of the change - **Issue:** the issue # it fixes, if applicable - **Dependencies:** any dependencies required for this change - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out! - [ ] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur --- .../langchain_openai/chat_models/base.py | 15 ++++---------- .../tests/unit_tests/chat_models/test_base.py | 20 ++++++++++++++++++- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index ed97a6c7379..4b8ec3e016e 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -141,14 +141,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: message_dict: Dict[str, Any] = { "content": message.content, } - if message.name is not None: - message_dict["name"] = message.name - elif ( - "name" in message.additional_kwargs - and message.additional_kwargs["name"] is not None - ): - # fall back on additional kwargs for backwards compatibility - message_dict["name"] = message.additional_kwargs["name"] + if (name := message.name or message.additional_kwargs.get("name")) is not None: + message_dict["name"] = name # populate role and additional message data if isinstance(message, ChatMessage): @@ -175,9 +169,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: message_dict["role"] = "tool" message_dict["tool_call_id"] = message.tool_call_id - # tool message doesn't have name: https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages - if message_dict["name"] is None: - del message_dict["name"] + supported_props = {"content", "role", "tool_call_id"} + message_dict = {k: v for k, v in message_dict.items() if k in supported_props} else: raise TypeError(f"Got unknown type {message}") return message_dict diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 87e7111959e..4a9a6498057 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -10,10 +10,14 @@ from langchain_core.messages import ( FunctionMessage, HumanMessage, SystemMessage, + ToolMessage, ) from langchain_openai import ChatOpenAI -from langchain_openai.chat_models.base import _convert_dict_to_message +from langchain_openai.chat_models.base import ( + _convert_dict_to_message, + _convert_message_to_dict, +) def test_openai_model_param() -> None: @@ -43,6 +47,7 @@ def test__convert_dict_to_message_human() -> None: result = _convert_dict_to_message(message) expected_output = HumanMessage(content="foo") assert result == expected_output + assert _convert_message_to_dict(expected_output) == message def test__convert_dict_to_message_human_with_name() -> None: @@ -50,6 +55,7 @@ def test__convert_dict_to_message_human_with_name() -> None: result = _convert_dict_to_message(message) expected_output = HumanMessage(content="foo", name="test") assert result == expected_output + assert _convert_message_to_dict(expected_output) == message def test__convert_dict_to_message_ai() -> None: @@ -57,6 +63,7 @@ def test__convert_dict_to_message_ai() -> None: result = _convert_dict_to_message(message) expected_output = AIMessage(content="foo") assert result == expected_output + assert _convert_message_to_dict(expected_output) == message def test__convert_dict_to_message_ai_with_name() -> None: @@ -64,6 +71,7 @@ def test__convert_dict_to_message_ai_with_name() -> None: result = _convert_dict_to_message(message) expected_output = AIMessage(content="foo", name="test") assert result == expected_output + assert _convert_message_to_dict(expected_output) == message def test__convert_dict_to_message_system() -> None: @@ -71,6 +79,7 @@ def test__convert_dict_to_message_system() -> None: result = _convert_dict_to_message(message) expected_output = SystemMessage(content="foo") assert result == expected_output + assert _convert_message_to_dict(expected_output) == message def test__convert_dict_to_message_system_with_name() -> None: @@ -78,6 +87,15 @@ def test__convert_dict_to_message_system_with_name() -> None: result = _convert_dict_to_message(message) expected_output = SystemMessage(content="foo", name="test") assert result == expected_output + assert _convert_message_to_dict(expected_output) == message + + +def test__convert_dict_to_message_tool() -> None: + message = {"role": "tool", "content": "foo", "tool_call_id": "bar"} + result = _convert_dict_to_message(message) + expected_output = ToolMessage(content="foo", tool_call_id="bar") + assert result == expected_output + assert _convert_message_to_dict(expected_output) == message @pytest.fixture