community[patch]: cross_encoders flatten namespaces (#20183)

Issue `langchain_community.cross_encoders` didn't have flattening
namespace code in the __init__.py file.
Changes:
- added code to flattening namespaces (used #20050 as a template)
- added ut for a change
- added missed `test_imports` for `chat_loaders` and
`chat_message_histories` modules
This commit is contained in:
Leonid Ganeline 2024-04-08 17:50:23 -07:00 committed by GitHub
parent 1af7133828
commit 2f8dd1a161
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 91 additions and 12 deletions

View File

@ -1,4 +1,4 @@
"""**Cross encoders** are wrappers around cross encoder models from different APIs and
"""**Cross encoders** are wrappers around cross encoder models from different APIs and
services.
**Cross encoder models** can be LLMs or not.
@ -9,18 +9,22 @@
BaseCrossEncoder --> <name>CrossEncoder # Examples: SagemakerEndpointCrossEncoder
"""
import importlib
from typing import TYPE_CHECKING, Any
import logging
from langchain_community.cross_encoders.base import BaseCrossEncoder
from langchain_community.cross_encoders.fake import FakeCrossEncoder
from langchain_community.cross_encoders.huggingface import HuggingFaceCrossEncoder
from langchain_community.cross_encoders.sagemaker_endpoint import (
SagemakerEndpointCrossEncoder,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from langchain_community.cross_encoders.base import (
BaseCrossEncoder, # noqa: F401
)
from langchain_community.cross_encoders.fake import (
FakeCrossEncoder, # noqa: F401
)
from langchain_community.cross_encoders.huggingface import (
HuggingFaceCrossEncoder, # noqa: F401
)
from langchain_community.cross_encoders.sagemaker_endpoint import (
SagemakerEndpointCrossEncoder, # noqa: F401
)
__all__ = [
"BaseCrossEncoder",
@ -28,3 +32,17 @@ __all__ = [
"HuggingFaceCrossEncoder",
"SagemakerEndpointCrossEncoder",
]
_module_lookup = {
"BaseCrossEncoder": "langchain_community.cross_encoders.base",
"FakeCrossEncoder": "langchain_community.cross_encoders.fake",
"HuggingFaceCrossEncoder": "langchain_community.cross_encoders.huggingface",
"SagemakerEndpointCrossEncoder": "langchain_community.cross_encoders.sagemaker_endpoint", # noqa: E501
}
def __getattr__(name: str) -> Any:
if name in _module_lookup:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")

View File

@ -0,0 +1,18 @@
from langchain_community.chat_loaders import _module_lookup
EXPECTED_ALL = [
"BaseChatLoader",
"FolderFacebookMessengerChatLoader",
"GMailLoader",
"IMessageChatLoader",
"LangSmithDatasetChatLoader",
"LangSmithRunChatLoader",
"SingleFileFacebookMessengerChatLoader",
"SlackChatLoader",
"TelegramChatLoader",
"WhatsAppChatLoader",
]
def test_all_imports() -> None:
assert set(_module_lookup.keys()) == set(EXPECTED_ALL)

View File

@ -0,0 +1,29 @@
from langchain_community.chat_message_histories import _module_lookup
EXPECTED_ALL = [
"AstraDBChatMessageHistory",
"CassandraChatMessageHistory",
"ChatMessageHistory",
"CosmosDBChatMessageHistory",
"DynamoDBChatMessageHistory",
"ElasticsearchChatMessageHistory",
"FileChatMessageHistory",
"FirestoreChatMessageHistory",
"MomentoChatMessageHistory",
"MongoDBChatMessageHistory",
"Neo4jChatMessageHistory",
"PostgresChatMessageHistory",
"RedisChatMessageHistory",
"RocksetChatMessageHistory",
"SQLChatMessageHistory",
"SingleStoreDBChatMessageHistory",
"StreamlitChatMessageHistory",
"TiDBChatMessageHistory",
"UpstashRedisChatMessageHistory",
"XataChatMessageHistory",
"ZepChatMessageHistory",
]
def test_all_imports() -> None:
assert set(_module_lookup.keys()) == set(EXPECTED_ALL)

View File

@ -0,0 +1,14 @@
from langchain_community.cross_encoders import __all__, _module_lookup
EXPECTED_ALL = [
"BaseCrossEncoder",
"FakeCrossEncoder",
"HuggingFaceCrossEncoder",
"SagemakerEndpointCrossEncoder",
]
def test_all_imports() -> None:
"""Test that __all__ is correctly set."""
assert set(__all__) == set(EXPECTED_ALL)
assert set(__all__) == set(_module_lookup.keys())