mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-01 10:54:15 +00:00
langchain[patch]: Fix how the serializable test identifies serializable objects (#21165)
dir() will not work if we're using optional imports. The only way to do this is by using contents of __all__
This commit is contained in:
parent
23c5d87311
commit
2914abd747
@ -1,5 +1,7 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
from types import ModuleType
|
||||
|
||||
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
|
||||
|
||||
@ -8,37 +10,71 @@ def import_all_modules(package_name: str) -> dict:
|
||||
package = importlib.import_module(package_name)
|
||||
classes: dict = {}
|
||||
|
||||
for attribute_name in dir(package):
|
||||
attribute = getattr(package, attribute_name)
|
||||
if hasattr(attribute, "is_lc_serializable") and isinstance(attribute, type):
|
||||
def _handle_module(module: ModuleType) -> None:
|
||||
# Iterate over all members of the module
|
||||
|
||||
names = dir(module)
|
||||
|
||||
if hasattr(module, "__all__"):
|
||||
names += list(module.__all__)
|
||||
|
||||
names = sorted(set(names))
|
||||
|
||||
for name in names:
|
||||
# Check if it's a class or function
|
||||
attr = getattr(module, name)
|
||||
|
||||
if not inspect.isclass(attr):
|
||||
continue
|
||||
|
||||
if not hasattr(attr, "is_lc_serializable") or not isinstance(attr, type):
|
||||
continue
|
||||
|
||||
if (
|
||||
isinstance(attribute.is_lc_serializable(), bool) # type: ignore
|
||||
and attribute.is_lc_serializable() # type: ignore
|
||||
isinstance(attr.is_lc_serializable(), bool) # type: ignore
|
||||
and attr.is_lc_serializable() # type: ignore
|
||||
):
|
||||
key = tuple(attribute.lc_id()) # type: ignore
|
||||
value = tuple(attribute.__module__.split(".") + [attribute.__name__])
|
||||
key = tuple(attr.lc_id()) # type: ignore
|
||||
value = tuple(attr.__module__.split(".") + [attr.__name__])
|
||||
if key in classes and classes[key] != value:
|
||||
raise ValueError
|
||||
classes[key] = value
|
||||
if hasattr(package, "__path__"):
|
||||
for loader, module_name, is_pkg in pkgutil.walk_packages(
|
||||
package.__path__, package_name + "."
|
||||
):
|
||||
if module_name not in (
|
||||
"langchain.chains.llm_bash",
|
||||
"langchain.chains.llm_symbolic_math",
|
||||
"langchain_community.tools.python",
|
||||
"langchain_community.vectorstores._pgvector_data_models",
|
||||
):
|
||||
importlib.import_module(module_name)
|
||||
new_classes = import_all_modules(module_name)
|
||||
for k, v in new_classes.items():
|
||||
if k in classes and classes[k] != v:
|
||||
raise ValueError
|
||||
classes[k] = v
|
||||
|
||||
_handle_module(package)
|
||||
|
||||
for importer, modname, ispkg in pkgutil.walk_packages(
|
||||
package.__path__, package.__name__ + "."
|
||||
):
|
||||
try:
|
||||
module = importlib.import_module(modname)
|
||||
except ModuleNotFoundError:
|
||||
continue
|
||||
_handle_module(module)
|
||||
|
||||
return classes
|
||||
|
||||
|
||||
def test_import_all_modules() -> None:
|
||||
"""Test import all modules works as expected"""
|
||||
all_modules = import_all_modules("langchain")
|
||||
filtered_modules = [
|
||||
k
|
||||
for k in all_modules
|
||||
if len(k) == 4 and tuple(k[:2]) == ("langchain", "chat_models")
|
||||
]
|
||||
# This test will need to be updated if new serializable classes are added
|
||||
# to community
|
||||
assert filtered_modules == [
|
||||
("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"),
|
||||
("langchain", "chat_models", "bedrock", "BedrockChat"),
|
||||
("langchain", "chat_models", "anthropic", "ChatAnthropic"),
|
||||
("langchain", "chat_models", "fireworks", "ChatFireworks"),
|
||||
("langchain", "chat_models", "google_palm", "ChatGooglePalm"),
|
||||
("langchain", "chat_models", "openai", "ChatOpenAI"),
|
||||
("langchain", "chat_models", "vertexai", "ChatVertexAI"),
|
||||
]
|
||||
|
||||
|
||||
def test_serializable_mapping() -> None:
|
||||
to_skip = {
|
||||
# This should have had a different namespace, as it was never
|
||||
@ -59,6 +95,7 @@ def test_serializable_mapping() -> None:
|
||||
),
|
||||
}
|
||||
serializable_modules = import_all_modules("langchain")
|
||||
|
||||
missing = set(SERIALIZABLE_MAPPING).difference(
|
||||
set(serializable_modules).union(to_skip)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user