langchain[patch]: Fix how the serializable test identifies serializable objects (#21165)

`dir()` will not work if we're using optional imports. The only way to identify the serializable objects in that case is to use the contents of `__all__`.
This commit is contained in:
Eugene Yurtsev 2024-05-01 15:56:11 -04:00 committed by GitHub
parent 23c5d87311
commit 2914abd747
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,7 @@
import importlib import importlib
import inspect
import pkgutil import pkgutil
from types import ModuleType
from langchain_core.load.mapping import SERIALIZABLE_MAPPING from langchain_core.load.mapping import SERIALIZABLE_MAPPING
@ -8,37 +10,71 @@ def import_all_modules(package_name: str) -> dict:
package = importlib.import_module(package_name) package = importlib.import_module(package_name)
classes: dict = {} classes: dict = {}
for attribute_name in dir(package): def _handle_module(module: ModuleType) -> None:
attribute = getattr(package, attribute_name) # Iterate over all members of the module
if hasattr(attribute, "is_lc_serializable") and isinstance(attribute, type):
names = dir(module)
if hasattr(module, "__all__"):
names += list(module.__all__)
names = sorted(set(names))
for name in names:
# Only classes are of interest; skip any other kind of attribute
attr = getattr(module, name)
if not inspect.isclass(attr):
continue
if not hasattr(attr, "is_lc_serializable") or not isinstance(attr, type):
continue
if ( if (
isinstance(attribute.is_lc_serializable(), bool) # type: ignore isinstance(attr.is_lc_serializable(), bool) # type: ignore
and attribute.is_lc_serializable() # type: ignore and attr.is_lc_serializable() # type: ignore
): ):
key = tuple(attribute.lc_id()) # type: ignore key = tuple(attr.lc_id()) # type: ignore
value = tuple(attribute.__module__.split(".") + [attribute.__name__]) value = tuple(attr.__module__.split(".") + [attr.__name__])
if key in classes and classes[key] != value: if key in classes and classes[key] != value:
raise ValueError raise ValueError
classes[key] = value classes[key] = value
if hasattr(package, "__path__"):
for loader, module_name, is_pkg in pkgutil.walk_packages( _handle_module(package)
package.__path__, package_name + "."
): for importer, modname, ispkg in pkgutil.walk_packages(
if module_name not in ( package.__path__, package.__name__ + "."
"langchain.chains.llm_bash", ):
"langchain.chains.llm_symbolic_math", try:
"langchain_community.tools.python", module = importlib.import_module(modname)
"langchain_community.vectorstores._pgvector_data_models", except ModuleNotFoundError:
): continue
importlib.import_module(module_name) _handle_module(module)
new_classes = import_all_modules(module_name)
for k, v in new_classes.items():
if k in classes and classes[k] != v:
raise ValueError
classes[k] = v
return classes return classes
def test_import_all_modules() -> None:
    """Check that ``import_all_modules`` finds the expected chat model entries.

    Collects the keys returned for the ``langchain`` package, keeps only the
    4-element keys that start with ``("langchain", "chat_models")``, and
    compares them (including their order) against a fixed list.
    """
    discovered = import_all_modules("langchain")
    prefix = ("langchain", "chat_models")

    chat_model_keys = []
    for key in discovered:
        if len(key) == 4 and tuple(key[:2]) == prefix:
            chat_model_keys.append(key)

    # This list will need to be updated if new serializable classes are added
    # to community.
    expected = [
        ("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"),
        ("langchain", "chat_models", "bedrock", "BedrockChat"),
        ("langchain", "chat_models", "anthropic", "ChatAnthropic"),
        ("langchain", "chat_models", "fireworks", "ChatFireworks"),
        ("langchain", "chat_models", "google_palm", "ChatGooglePalm"),
        ("langchain", "chat_models", "openai", "ChatOpenAI"),
        ("langchain", "chat_models", "vertexai", "ChatVertexAI"),
    ]
    assert chat_model_keys == expected
def test_serializable_mapping() -> None: def test_serializable_mapping() -> None:
to_skip = { to_skip = {
# This should have had a different namespace, as it was never # This should have had a different namespace, as it was never
@ -59,6 +95,7 @@ def test_serializable_mapping() -> None:
), ),
} }
serializable_modules = import_all_modules("langchain") serializable_modules = import_all_modules("langchain")
missing = set(SERIALIZABLE_MAPPING).difference( missing = set(SERIALIZABLE_MAPPING).difference(
set(serializable_modules).union(to_skip) set(serializable_modules).union(to_skip)
) )