diff --git a/libs/langchain/langchain/tools/json/tool.py b/libs/langchain/langchain/tools/json/tool.py index 6f6473d51e6..6c75de20ce5 100644 --- a/libs/langchain/langchain/tools/json/tool.py +++ b/libs/langchain/langchain/tools/json/tool.py @@ -20,7 +20,7 @@ def _parse_input(text: str) -> List[Union[str, int]]: """Parse input of the form data["key1"][0]["key2"] into a list of keys.""" _res = re.findall(r"\[.*?]", text) # strip the brackets and quotes, convert to int if possible - res = [i[1:-1].replace('"', "") for i in _res] + res = [i[1:-1].replace('"', "").replace("'", "") for i in _res] res = [int(i) if i.isdigit() else i for i in res] return res diff --git a/libs/langchain/tests/unit_tests/tools/test_json.py b/libs/langchain/tests/unit_tests/tools/test_json.py index 36a96595e03..b677b1577d3 100644 --- a/libs/langchain/tests/unit_tests/tools/test_json.py +++ b/libs/langchain/tests/unit_tests/tools/test_json.py @@ -30,6 +30,10 @@ def test_json_spec_value() -> None: assert spec.value('data["baz"]') == "{'test': {'foo': [1, 2, 3]}}" assert spec.value('data["baz"]["test"]') == "{'foo': [1, 2, 3]}" assert spec.value('data["baz"]["test"]["foo"]') == "[1, 2, 3]" + assert spec.value("data['foo']") == "bar" + assert spec.value("data['baz']") == "{'test': {'foo': [1, 2, 3]}}" + assert spec.value("data['baz']['test']") == "{'foo': [1, 2, 3]}" + assert spec.value("data['baz']['test']['foo']") == "[1, 2, 3]" def test_json_spec_value_max_length() -> None: