langchain/libs/community/tests/unit_tests/load/test_serializable.py
Christophe Bornet ada740b5b9
community: Add ruff rule PGH003 (#30812)
See https://docs.astral.sh/ruff/rules/blanket-type-ignore/

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
2025-04-14 02:32:13 +00:00

156 lines
5.1 KiB
Python

import importlib
import inspect
import pkgutil
from types import ModuleType
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
def import_all_modules(package_name: str) -> dict:
package = importlib.import_module(package_name)
classes: dict = {}
def _handle_module(module: ModuleType) -> None:
# Iterate over all members of the module
names = dir(module)
if hasattr(module, "__all__"):
names += list(module.__all__)
names = sorted(set(names))
for name in names:
# Check if it's a class or function
attr = getattr(module, name)
if not inspect.isclass(attr):
continue
if not hasattr(attr, "is_lc_serializable") or not isinstance(attr, type):
continue
if (
isinstance(attr.is_lc_serializable(), bool)
and attr.is_lc_serializable()
):
key = tuple(attr.lc_id())
value = tuple(attr.__module__.split(".") + [attr.__name__])
if key in classes and classes[key] != value:
raise ValueError
classes[key] = value
_handle_module(package)
for importer, modname, ispkg in pkgutil.walk_packages(
package.__path__, package.__name__ + "."
):
try:
module = importlib.import_module(modname)
except ModuleNotFoundError:
continue
_handle_module(module)
return classes
def test_import_all_modules() -> None:
"""Test import all modules works as expected"""
all_modules = import_all_modules("langchain")
filtered_modules = [
k
for k in all_modules
if len(k) == 4 and tuple(k[:2]) == ("langchain", "chat_models")
]
# This test will need to be updated if new serializable classes are added
# to community
assert sorted(filtered_modules) == sorted(
[
("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"),
("langchain", "chat_models", "bedrock", "BedrockChat"),
("langchain", "chat_models", "anthropic", "ChatAnthropic"),
("langchain", "chat_models", "fireworks", "ChatFireworks"),
("langchain", "chat_models", "google_palm", "ChatGooglePalm"),
("langchain", "chat_models", "openai", "ChatOpenAI"),
("langchain", "chat_models", "vertexai", "ChatVertexAI"),
]
)
def test_serializable_mapping() -> None:
to_skip = {
# This should have had a different namespace, as it was never
# exported from the langchain module, but we keep for whoever has
# already serialized it.
("langchain", "prompts", "image", "ImagePromptTemplate"): (
"langchain_core",
"prompts",
"image",
"ImagePromptTemplate",
),
# This is not exported from langchain, only langchain_core
("langchain_core", "prompts", "structured", "StructuredPrompt"): (
"langchain_core",
"prompts",
"structured",
"StructuredPrompt",
),
# This is not exported from langchain, only langchain_core
("langchain", "schema", "messages", "RemoveMessage"): (
"langchain_core",
"messages",
"modifier",
"RemoveMessage",
),
("langchain", "chat_models", "mistralai", "ChatMistralAI"): (
"langchain_mistralai",
"chat_models",
"ChatMistralAI",
),
("langchain_groq", "chat_models", "ChatGroq"): (
"langchain_groq",
"chat_models",
"ChatGroq",
),
("langchain_sambanova", "chat_models", "ChatSambaNovaCloud"): (
"langchain_sambanova",
"chat_models",
"ChatSambaNovaCloud",
),
("langchain_sambanova", "chat_models", "ChatSambaStudio"): (
"langchain_sambanova",
"chat_models",
"ChatSambaStudio",
),
# TODO(0.3): For now we're skipping the below two tests. Need to fix
# so that it only runs when langchain-aws, langchain-google-genai
# are installed.
("langchain", "chat_models", "bedrock", "ChatBedrock"): (
"langchain_aws",
"chat_models",
"bedrock",
"ChatBedrock",
),
("langchain_google_genai", "chat_models", "ChatGoogleGenerativeAI"): (
"langchain_google_genai",
"chat_models",
"ChatGoogleGenerativeAI",
),
}
serializable_modules = import_all_modules("langchain")
missing = set(SERIALIZABLE_MAPPING).difference(
set(serializable_modules).union(to_skip)
)
assert missing == set()
extra = set(serializable_modules).difference(SERIALIZABLE_MAPPING)
assert extra == set()
for k, import_path in serializable_modules.items():
import_dir, import_obj = import_path[:-1], import_path[-1]
# Import module
mod = importlib.import_module(".".join(import_dir))
# Import class
cls = getattr(mod, import_obj)
assert list(k) == cls.lc_id()