blacklist community load from path

This commit is contained in:
Bagatur
2024-09-11 18:13:44 -07:00
parent e1d194160a
commit dff97da6ab

View File

@@ -21,6 +21,7 @@ DEFAULT_NAMESPACES = [
"langchain_google_genai", "langchain_google_genai",
"langchain_aws", "langchain_aws",
] ]
DISALLOW_LOAD_FROM_PATH = ["langchain_community"]
ALL_SERIALIZABLE_MAPPINGS = { ALL_SERIALIZABLE_MAPPINGS = {
**SERIALIZABLE_MAPPING, **SERIALIZABLE_MAPPING,
@@ -98,17 +99,24 @@ class Reviver:
raise ValueError(f"Invalid namespace: {value}") raise ValueError(f"Invalid namespace: {value}")
# If namespace is in mapping, used custom path # If namespace is in mapping, used custom path
mapping_key = tuple(namespace + [name])
if ( if (
namespace[0] in DEFAULT_NAMESPACES namespace[0] in DEFAULT_NAMESPACES
and (key := tuple(namespace + [name])) in ALL_SERIALIZABLE_MAPPINGS and mapping_key in ALL_SERIALIZABLE_MAPPINGS
): ):
import_path = ALL_SERIALIZABLE_MAPPINGS[key] import_path = ALL_SERIALIZABLE_MAPPINGS[mapping_key]
# Split into module and name # Split into module and name
import_dir, import_obj = import_path[:-1], import_path[-1] import_dir, import_obj = import_path[:-1], import_path[-1]
# Import module # Import module
mod = importlib.import_module(".".join(import_dir)) mod = importlib.import_module(".".join(import_dir))
# Import class # Import class
cls = getattr(mod, import_obj) cls = getattr(mod, import_obj)
elif namespace[0] in DISALLOW_LOAD_FROM_PATH:
raise ValueError(
"Trying to deserialize something that cannot "
"be deserialized in current version of langchain-core: "
f"{mapping_key}."
)
# Otherwise, load by path # Otherwise, load by path
else: else:
mod = importlib.import_module(".".join(namespace)) mod = importlib.import_module(".".join(namespace))