mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
core: add additional import mappings to loads (#26406)
Support using additional import mapping. This allows users to override old mappings/add new imports to the loads function. - [x ] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [ x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/
This commit is contained in:
parent
1d98937e8d
commit
8a02fd9c01
@ -1,7 +1,7 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from langchain_core._api import beta
|
from langchain_core._api import beta
|
||||||
from langchain_core.load.mapping import (
|
from langchain_core.load.mapping import (
|
||||||
@ -36,6 +36,9 @@ class Reviver:
|
|||||||
secrets_map: Optional[Dict[str, str]] = None,
|
secrets_map: Optional[Dict[str, str]] = None,
|
||||||
valid_namespaces: Optional[List[str]] = None,
|
valid_namespaces: Optional[List[str]] = None,
|
||||||
secrets_from_env: bool = True,
|
secrets_from_env: bool = True,
|
||||||
|
additional_import_mappings: Optional[
|
||||||
|
Dict[Tuple[str, ...], Tuple[str, ...]]
|
||||||
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the reviver.
|
"""Initialize the reviver.
|
||||||
|
|
||||||
@ -47,15 +50,27 @@ class Reviver:
|
|||||||
to allow to be deserialized. Defaults to None.
|
to allow to be deserialized. Defaults to None.
|
||||||
secrets_from_env: Whether to load secrets from the environment.
|
secrets_from_env: Whether to load secrets from the environment.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
|
additional_import_mappings: A dictionary of additional namespace mappings
|
||||||
|
You can use this to override default mappings or add new mappings.
|
||||||
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.secrets_from_env = secrets_from_env
|
self.secrets_from_env = secrets_from_env
|
||||||
self.secrets_map = secrets_map or dict()
|
self.secrets_map = secrets_map or dict()
|
||||||
# By default only support langchain, but user can pass in additional namespaces
|
# By default, only support langchain, but user can pass in additional namespaces
|
||||||
self.valid_namespaces = (
|
self.valid_namespaces = (
|
||||||
[*DEFAULT_NAMESPACES, *valid_namespaces]
|
[*DEFAULT_NAMESPACES, *valid_namespaces]
|
||||||
if valid_namespaces
|
if valid_namespaces
|
||||||
else DEFAULT_NAMESPACES
|
else DEFAULT_NAMESPACES
|
||||||
)
|
)
|
||||||
|
self.additional_import_mappings = additional_import_mappings or dict()
|
||||||
|
self.import_mappings = (
|
||||||
|
{
|
||||||
|
**ALL_SERIALIZABLE_MAPPINGS,
|
||||||
|
**self.additional_import_mappings,
|
||||||
|
}
|
||||||
|
if self.additional_import_mappings
|
||||||
|
else ALL_SERIALIZABLE_MAPPINGS
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, value: Dict[str, Any]) -> Any:
|
def __call__(self, value: Dict[str, Any]) -> Any:
|
||||||
if (
|
if (
|
||||||
@ -96,16 +111,16 @@ class Reviver:
|
|||||||
raise ValueError(f"Invalid namespace: {value}")
|
raise ValueError(f"Invalid namespace: {value}")
|
||||||
|
|
||||||
# If namespace is in known namespaces, try to use mapping
|
# If namespace is in known namespaces, try to use mapping
|
||||||
|
key = tuple(namespace + [name])
|
||||||
if namespace[0] in DEFAULT_NAMESPACES:
|
if namespace[0] in DEFAULT_NAMESPACES:
|
||||||
# Get the importable path
|
# Get the importable path
|
||||||
key = tuple(namespace + [name])
|
if key not in self.import_mappings:
|
||||||
if key not in ALL_SERIALIZABLE_MAPPINGS:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Trying to deserialize something that cannot "
|
"Trying to deserialize something that cannot "
|
||||||
"be deserialized in current version of langchain-core: "
|
"be deserialized in current version of langchain-core: "
|
||||||
f"{key}"
|
f"{key}"
|
||||||
)
|
)
|
||||||
import_path = ALL_SERIALIZABLE_MAPPINGS[key]
|
import_path = self.import_mappings[key]
|
||||||
# Split into module and name
|
# Split into module and name
|
||||||
import_dir, import_obj = import_path[:-1], import_path[-1]
|
import_dir, import_obj = import_path[:-1], import_path[-1]
|
||||||
# Import module
|
# Import module
|
||||||
@ -113,6 +128,11 @@ class Reviver:
|
|||||||
# Import class
|
# Import class
|
||||||
cls = getattr(mod, import_obj)
|
cls = getattr(mod, import_obj)
|
||||||
# Otherwise, load by path
|
# Otherwise, load by path
|
||||||
|
else:
|
||||||
|
if key in self.additional_import_mappings:
|
||||||
|
import_path = self.import_mappings[key]
|
||||||
|
mod = importlib.import_module(".".join(import_path[:-1]))
|
||||||
|
name = import_path[-1]
|
||||||
else:
|
else:
|
||||||
mod = importlib.import_module(".".join(namespace))
|
mod = importlib.import_module(".".join(namespace))
|
||||||
cls = getattr(mod, name)
|
cls = getattr(mod, name)
|
||||||
@ -136,6 +156,7 @@ def loads(
|
|||||||
secrets_map: Optional[Dict[str, str]] = None,
|
secrets_map: Optional[Dict[str, str]] = None,
|
||||||
valid_namespaces: Optional[List[str]] = None,
|
valid_namespaces: Optional[List[str]] = None,
|
||||||
secrets_from_env: bool = True,
|
secrets_from_env: bool = True,
|
||||||
|
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Revive a LangChain class from a JSON string.
|
"""Revive a LangChain class from a JSON string.
|
||||||
Equivalent to `load(json.loads(text))`.
|
Equivalent to `load(json.loads(text))`.
|
||||||
@ -149,12 +170,18 @@ def loads(
|
|||||||
to allow to be deserialized. Defaults to None.
|
to allow to be deserialized. Defaults to None.
|
||||||
secrets_from_env: Whether to load secrets from the environment.
|
secrets_from_env: Whether to load secrets from the environment.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
|
additional_import_mappings: A dictionary of additional namespace mappings
|
||||||
|
You can use this to override default mappings or add new mappings.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Revived LangChain objects.
|
Revived LangChain objects.
|
||||||
"""
|
"""
|
||||||
return json.loads(
|
return json.loads(
|
||||||
text, object_hook=Reviver(secrets_map, valid_namespaces, secrets_from_env)
|
text,
|
||||||
|
object_hook=Reviver(
|
||||||
|
secrets_map, valid_namespaces, secrets_from_env, additional_import_mappings
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -165,6 +192,7 @@ def load(
|
|||||||
secrets_map: Optional[Dict[str, str]] = None,
|
secrets_map: Optional[Dict[str, str]] = None,
|
||||||
valid_namespaces: Optional[List[str]] = None,
|
valid_namespaces: Optional[List[str]] = None,
|
||||||
secrets_from_env: bool = True,
|
secrets_from_env: bool = True,
|
||||||
|
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Revive a LangChain class from a JSON object. Use this if you already
|
"""Revive a LangChain class from a JSON object. Use this if you already
|
||||||
have a parsed JSON object, eg. from `json.load` or `orjson.loads`.
|
have a parsed JSON object, eg. from `json.load` or `orjson.loads`.
|
||||||
@ -178,11 +206,16 @@ def load(
|
|||||||
to allow to be deserialized. Defaults to None.
|
to allow to be deserialized. Defaults to None.
|
||||||
secrets_from_env: Whether to load secrets from the environment.
|
secrets_from_env: Whether to load secrets from the environment.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
|
additional_import_mappings: A dictionary of additional namespace mappings
|
||||||
|
You can use this to override default mappings or add new mappings.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Revived LangChain objects.
|
Revived LangChain objects.
|
||||||
"""
|
"""
|
||||||
reviver = Reviver(secrets_map, valid_namespaces, secrets_from_env)
|
reviver = Reviver(
|
||||||
|
secrets_map, valid_namespaces, secrets_from_env, additional_import_mappings
|
||||||
|
)
|
||||||
|
|
||||||
def _load(obj: Any) -> Any:
|
def _load(obj: Any) -> Any:
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from langchain_core.load import Serializable, dumpd
|
from langchain_core.load import Serializable, dumpd, load
|
||||||
from langchain_core.load.serializable import _is_field_useful
|
from langchain_core.load.serializable import _is_field_useful
|
||||||
from langchain_core.pydantic_v1 import Field
|
from langchain_core.pydantic_v1 import Field
|
||||||
|
|
||||||
@ -107,3 +107,61 @@ def test__is_field_useful() -> None:
|
|||||||
foo = Foo(x=default_x, y=default_y, z=ArrayObj())
|
foo = Foo(x=default_x, y=default_y, z=ArrayObj())
|
||||||
assert not _is_field_useful(foo, "x", foo.x)
|
assert not _is_field_useful(foo, "x", foo.x)
|
||||||
assert not _is_field_useful(foo, "y", foo.y)
|
assert not _is_field_useful(foo, "y", foo.y)
|
||||||
|
|
||||||
|
|
||||||
|
class Foo(Serializable):
|
||||||
|
bar: int
|
||||||
|
baz: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_deserialization() -> None:
|
||||||
|
foo = Foo(bar=1, baz="hello")
|
||||||
|
assert foo.lc_id() == ["tests", "unit_tests", "load", "test_serializable", "Foo"]
|
||||||
|
serialized_foo = dumpd(foo)
|
||||||
|
assert serialized_foo == {
|
||||||
|
"id": ["tests", "unit_tests", "load", "test_serializable", "Foo"],
|
||||||
|
"kwargs": {"bar": 1, "baz": "hello"},
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
}
|
||||||
|
new_foo = load(serialized_foo, valid_namespaces=["tests"])
|
||||||
|
assert new_foo == foo
|
||||||
|
|
||||||
|
|
||||||
|
class Foo2(Serializable):
|
||||||
|
bar: int
|
||||||
|
baz: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_deserialization_with_additional_imports() -> None:
|
||||||
|
foo = Foo(bar=1, baz="hello")
|
||||||
|
assert foo.lc_id() == ["tests", "unit_tests", "load", "test_serializable", "Foo"]
|
||||||
|
serialized_foo = dumpd(foo)
|
||||||
|
assert serialized_foo == {
|
||||||
|
"id": ["tests", "unit_tests", "load", "test_serializable", "Foo"],
|
||||||
|
"kwargs": {"bar": 1, "baz": "hello"},
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
}
|
||||||
|
new_foo = load(
|
||||||
|
serialized_foo,
|
||||||
|
valid_namespaces=["tests"],
|
||||||
|
additional_import_mappings={
|
||||||
|
("tests", "unit_tests", "load", "test_serializable", "Foo"): (
|
||||||
|
"tests",
|
||||||
|
"unit_tests",
|
||||||
|
"load",
|
||||||
|
"test_serializable",
|
||||||
|
"Foo2",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert isinstance(new_foo, Foo2)
|
||||||
|
Loading…
Reference in New Issue
Block a user