From 8a02fd9c01686a9f7aba524e3ffb6ad6c42cca58 Mon Sep 17 00:00:00 2001 From: langchain-infra Date: Fri, 13 Sep 2024 12:39:58 -0400 Subject: [PATCH] core: add additional import mappings to loads (#26406) Support using additional import mapping. This allows users to override old mappings/add new imports to the loads function. - [x ] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [ x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --- libs/core/langchain_core/load/load.py | 49 ++++++++++++--- .../unit_tests/load/test_serializable.py | 60 ++++++++++++++++++- 2 files changed, 100 insertions(+), 9 deletions(-) diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py index fb6e3ec2a1d..93857adf583 100644 --- a/libs/core/langchain_core/load/load.py +++ b/libs/core/langchain_core/load/load.py @@ -1,7 +1,7 @@ import importlib import json import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from langchain_core._api import beta from langchain_core.load.mapping import ( @@ -36,6 +36,9 @@ class Reviver: secrets_map: Optional[Dict[str, str]] = None, valid_namespaces: Optional[List[str]] = None, secrets_from_env: bool = True, + additional_import_mappings: Optional[ + Dict[Tuple[str, ...], Tuple[str, ...]] + ] = None, ) -> None: """Initialize the reviver. @@ -47,15 +50,27 @@ class Reviver: to allow to be deserialized. Defaults to None. secrets_from_env: Whether to load secrets from the environment. Defaults to True. + additional_import_mappings: A dictionary of additional namespace mappings + You can use this to override default mappings or add new mappings. + Defaults to None. """ self.secrets_from_env = secrets_from_env self.secrets_map = secrets_map or dict() - # By default only support langchain, but user can pass in additional namespaces + # By default, only support langchain, but user can pass in additional namespaces self.valid_namespaces = ( [*DEFAULT_NAMESPACES, *valid_namespaces] if valid_namespaces else DEFAULT_NAMESPACES ) + self.additional_import_mappings = additional_import_mappings or dict() + self.import_mappings = ( + { + **ALL_SERIALIZABLE_MAPPINGS, + **self.additional_import_mappings, + } + if self.additional_import_mappings + else ALL_SERIALIZABLE_MAPPINGS + ) def __call__(self, value: Dict[str, Any]) -> Any: if ( @@ -96,16 +111,16 @@ class Reviver: raise ValueError(f"Invalid namespace: {value}") # If namespace is in known namespaces, try to use mapping + key = tuple(namespace + [name]) if namespace[0] in DEFAULT_NAMESPACES: # Get the importable path - key = tuple(namespace + [name]) - if key not in ALL_SERIALIZABLE_MAPPINGS: + if key not in self.import_mappings: raise ValueError( "Trying to deserialize something that cannot " "be deserialized in current version of langchain-core: " f"{key}" ) - import_path = ALL_SERIALIZABLE_MAPPINGS[key] + import_path = self.import_mappings[key] # Split into module and name import_dir, import_obj = import_path[:-1], import_path[-1] # Import module @@ -114,7 +129,12 @@ class Reviver: cls = getattr(mod, import_obj) # Otherwise, load by path else: - mod = importlib.import_module(".".join(namespace)) + if key in self.additional_import_mappings: + import_path = self.import_mappings[key] + mod = importlib.import_module(".".join(import_path[:-1])) + name = import_path[-1] + else: + mod = importlib.import_module(".".join(namespace)) cls = getattr(mod, name) # The class must be a subclass of Serializable. @@ -136,6 +156,7 @@ def loads( secrets_map: Optional[Dict[str, str]] = None, valid_namespaces: Optional[List[str]] = None, secrets_from_env: bool = True, + additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None, ) -> Any: """Revive a LangChain class from a JSON string. Equivalent to `load(json.loads(text))`. @@ -149,12 +170,18 @@ def loads( to allow to be deserialized. Defaults to None. secrets_from_env: Whether to load secrets from the environment. Defaults to True. + additional_import_mappings: A dictionary of additional namespace mappings + You can use this to override default mappings or add new mappings. + Defaults to None. Returns: Revived LangChain objects. """ return json.loads( - text, object_hook=Reviver(secrets_map, valid_namespaces, secrets_from_env) + text, + object_hook=Reviver( + secrets_map, valid_namespaces, secrets_from_env, additional_import_mappings + ), ) @@ -165,6 +192,7 @@ def load( secrets_map: Optional[Dict[str, str]] = None, valid_namespaces: Optional[List[str]] = None, secrets_from_env: bool = True, + additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None, ) -> Any: """Revive a LangChain class from a JSON object. Use this if you already have a parsed JSON object, eg. from `json.load` or `orjson.loads`. @@ -178,11 +206,16 @@ def load( to allow to be deserialized. Defaults to None. secrets_from_env: Whether to load secrets from the environment. Defaults to True. + additional_import_mappings: A dictionary of additional namespace mappings + You can use this to override default mappings or add new mappings. + Defaults to None. Returns: Revived LangChain objects. """ - reviver = Reviver(secrets_map, valid_namespaces, secrets_from_env) + reviver = Reviver( + secrets_map, valid_namespaces, secrets_from_env, additional_import_mappings + ) def _load(obj: Any) -> Any: if isinstance(obj, dict): diff --git a/libs/core/tests/unit_tests/load/test_serializable.py b/libs/core/tests/unit_tests/load/test_serializable.py index 9ef83a6c89c..f03bdd91c6b 100644 --- a/libs/core/tests/unit_tests/load/test_serializable.py +++ b/libs/core/tests/unit_tests/load/test_serializable.py @@ -1,6 +1,6 @@ from typing import Dict -from langchain_core.load import Serializable, dumpd +from langchain_core.load import Serializable, dumpd, load from langchain_core.load.serializable import _is_field_useful from langchain_core.pydantic_v1 import Field @@ -107,3 +107,61 @@ def test__is_field_useful() -> None: foo = Foo(x=default_x, y=default_y, z=ArrayObj()) assert not _is_field_useful(foo, "x", foo.x) assert not _is_field_useful(foo, "y", foo.y) + + +class Foo(Serializable): + bar: int + baz: str + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + +def test_simple_deserialization() -> None: + foo = Foo(bar=1, baz="hello") + assert foo.lc_id() == ["tests", "unit_tests", "load", "test_serializable", "Foo"] + serialized_foo = dumpd(foo) + assert serialized_foo == { + "id": ["tests", "unit_tests", "load", "test_serializable", "Foo"], + "kwargs": {"bar": 1, "baz": "hello"}, + "lc": 1, + "type": "constructor", + } + new_foo = load(serialized_foo, valid_namespaces=["tests"]) + assert new_foo == foo + + +class Foo2(Serializable): + bar: int + baz: str + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + +def test_simple_deserialization_with_additional_imports() -> None: + foo = Foo(bar=1, baz="hello") + assert foo.lc_id() == ["tests", "unit_tests", "load", "test_serializable", "Foo"] + serialized_foo = dumpd(foo) + assert serialized_foo == { + "id": ["tests", "unit_tests", "load", "test_serializable", "Foo"], + "kwargs": {"bar": 1, "baz": "hello"}, + "lc": 1, + "type": "constructor", + } + new_foo = load( + serialized_foo, + valid_namespaces=["tests"], + additional_import_mappings={ + ("tests", "unit_tests", "load", "test_serializable", "Foo"): ( + "tests", + "unit_tests", + "load", + "test_serializable", + "Foo2", + ) + }, + ) + assert isinstance(new_foo, Foo2)