From b9636e5c987e1217afcdf83e9c311568ad50c304 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Dec 2023 17:37:12 -0800 Subject: [PATCH] Catch type errors in dumps/dumpd (#15336) These can happen for edge cases not covered by `default` handler (eg. "strange" keys in dicts) --- libs/core/langchain_core/load/dump.py | 16 +++++++++++----- .../langchain/tests/unit_tests/load/test_dump.py | 7 +++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/libs/core/langchain_core/load/dump.py b/libs/core/langchain_core/load/dump.py index 07c956a8400..783ab9271d4 100644 --- a/libs/core/langchain_core/load/dump.py +++ b/libs/core/langchain_core/load/dump.py @@ -17,11 +17,17 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str: """Return a json string representation of an object.""" if "default" in kwargs: raise ValueError("`default` should not be passed to dumps") - if pretty: - indent = kwargs.pop("indent", 2) - return json.dumps(obj, default=default, indent=indent, **kwargs) - else: - return json.dumps(obj, default=default, **kwargs) + try: + if pretty: + indent = kwargs.pop("indent", 2) + return json.dumps(obj, default=default, indent=indent, **kwargs) + else: + return json.dumps(obj, default=default, **kwargs) + except TypeError: + if pretty: + return json.dumps(to_json_not_implemented(obj), indent=indent, **kwargs) + else: + return json.dumps(to_json_not_implemented(obj), **kwargs) def dumpd(obj: Any) -> Dict[str, Any]: diff --git a/libs/langchain/tests/unit_tests/load/test_dump.py b/libs/langchain/tests/unit_tests/load/test_dump.py index 6eef608e101..0553eae7a1e 100644 --- a/libs/langchain/tests/unit_tests/load/test_dump.py +++ b/libs/langchain/tests/unit_tests/load/test_dump.py @@ -60,6 +60,13 @@ def test_person(snapshot: Any) -> None: assert Person.lc_id() == ["tests", "unit_tests", "load", "test_dump", "Person"] +def test_typeerror() -> None: + assert ( + dumps({(1, 2): 3}) + == """{"lc": 1, "type": "not_implemented", "id": ["builtins", "dict"], "repr": "{(1, 2): 3}"}""" # noqa: E501 + ) + + @pytest.mark.requires("openai") def test_serialize_openai_llm(snapshot: Any) -> None: llm = OpenAI(