core[patch]: support load from path for default namespaces (#26675)

This commit is contained in:
Bagatur 2024-09-19 14:47:27 -07:00 committed by GitHub
parent e8236e58f2
commit 409f35363b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -19,6 +19,17 @@ DEFAULT_NAMESPACES = [
"langchain_anthropic", "langchain_anthropic",
"langchain_groq", "langchain_groq",
"langchain_google_genai", "langchain_google_genai",
"langchain_aws",
"langchain_openai",
"langchain_google_vertexai",
"langchain_mistralai",
"langchain_fireworks",
]
# Namespaces for which only deserializing via the SERIALIZABLE_MAPPING is allowed.
# Load by path is not allowed.
DISALLOW_LOAD_FROM_PATH = [
"langchain_community",
"langchain",
] ]
ALL_SERIALIZABLE_MAPPINGS = { ALL_SERIALIZABLE_MAPPINGS = {
@ -103,39 +114,30 @@ class Reviver:
and value.get("id", None) is not None and value.get("id", None) is not None
): ):
[*namespace, name] = value["id"] [*namespace, name] = value["id"]
mapping_key = tuple(value["id"])
if namespace[0] not in self.valid_namespaces: if namespace[0] not in self.valid_namespaces:
raise ValueError(f"Invalid namespace: {value}") raise ValueError(f"Invalid namespace: {value}")
# The root namespace ["langchain"] is not a valid identifier.
# The root namespace "langchain" is not a valid identifier. elif namespace == ["langchain"]:
if len(namespace) == 1 and namespace[0] == "langchain":
raise ValueError(f"Invalid namespace: {value}") raise ValueError(f"Invalid namespace: {value}")
# Has explicit import path.
# If namespace is in known namespaces, try to use mapping elif mapping_key in self.import_mappings:
key = tuple(namespace + [name]) import_path = self.import_mappings[mapping_key]
if namespace[0] in DEFAULT_NAMESPACES: # Split into module and name
# Get the importable path import_dir, name = import_path[:-1], import_path[-1]
if key not in self.import_mappings: # Import module
mod = importlib.import_module(".".join(import_dir))
elif namespace[0] in DISALLOW_LOAD_FROM_PATH:
raise ValueError( raise ValueError(
"Trying to deserialize something that cannot " "Trying to deserialize something that cannot "
"be deserialized in current version of langchain-core: " "be deserialized in current version of langchain-core: "
f"{key}" f"{mapping_key}."
) )
import_path = self.import_mappings[key] # Otherwise, treat namespace as path.
# Split into module and name
import_dir, import_obj = import_path[:-1], import_path[-1]
# Import module
mod = importlib.import_module(".".join(import_dir))
# Import class
cls = getattr(mod, import_obj)
# Otherwise, load by path
else:
if key in self.additional_import_mappings:
import_path = self.import_mappings[key]
mod = importlib.import_module(".".join(import_path[:-1]))
name = import_path[-1]
else: else:
mod = importlib.import_module(".".join(namespace)) mod = importlib.import_module(".".join(namespace))
cls = getattr(mod, name) cls = getattr(mod, name)
# The class must be a subclass of Serializable. # The class must be a subclass of Serializable.