diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py index ba150a5a009..0fe91975c46 100644 --- a/libs/core/langchain_core/load/load.py +++ b/libs/core/langchain_core/load/load.py @@ -19,6 +19,17 @@ DEFAULT_NAMESPACES = [ "langchain_anthropic", "langchain_groq", "langchain_google_genai", + "langchain_aws", + "langchain_openai", + "langchain_google_vertexai", + "langchain_mistralai", + "langchain_fireworks", +] +# Namespaces for which only deserializing via the SERIALIZABLE_MAPPING is allowed. +# Load by path is not allowed. +DISALLOW_LOAD_FROM_PATH = [ + "langchain_community", + "langchain", ] ALL_SERIALIZABLE_MAPPINGS = { @@ -103,40 +114,31 @@ class Reviver: and value.get("id", None) is not None ): [*namespace, name] = value["id"] + mapping_key = tuple(value["id"]) if namespace[0] not in self.valid_namespaces: raise ValueError(f"Invalid namespace: {value}") - - # The root namespace "langchain" is not a valid identifier. - if len(namespace) == 1 and namespace[0] == "langchain": + # The root namespace ["langchain"] is not a valid identifier. + elif namespace == ["langchain"]: raise ValueError(f"Invalid namespace: {value}") - - # If namespace is in known namespaces, try to use mapping - key = tuple(namespace + [name]) - if namespace[0] in DEFAULT_NAMESPACES: - # Get the importable path - if key not in self.import_mappings: - raise ValueError( - "Trying to deserialize something that cannot " - "be deserialized in current version of langchain-core: " - f"{key}" - ) - import_path = self.import_mappings[key] + # Has explicit import path. + elif mapping_key in self.import_mappings: + import_path = self.import_mappings[mapping_key] # Split into module and name - import_dir, import_obj = import_path[:-1], import_path[-1] + import_dir, name = import_path[:-1], import_path[-1] # Import module mod = importlib.import_module(".".join(import_dir)) - # Import class - cls = getattr(mod, import_obj) - # Otherwise, load by path + elif namespace[0] in DISALLOW_LOAD_FROM_PATH: + raise ValueError( + "Trying to deserialize something that cannot " + "be deserialized in current version of langchain-core: " + f"{mapping_key}." + ) + # Otherwise, treat namespace as path. else: - if key in self.additional_import_mappings: - import_path = self.import_mappings[key] - mod = importlib.import_module(".".join(import_path[:-1])) - name = import_path[-1] - else: - mod = importlib.import_module(".".join(namespace)) - cls = getattr(mod, name) + mod = importlib.import_module(".".join(namespace)) + + cls = getattr(mod, name) # The class must be a subclass of Serializable. if not issubclass(cls, Serializable):