mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 22:03:52 +00:00
community[patch]: cross_encoders
flatten namespaces (#20183)
Issue `langchain_community.cross_encoders` didn't have flattening namespace code in the __init__.py file. Changes: - added code to flattening namespaces (used #20050 as a template) - added ut for a change - added missed `test_imports` for `chat_loaders` and `chat_message_histories` modules
This commit is contained in:
parent
1af7133828
commit
2f8dd1a161
@ -9,18 +9,22 @@
|
|||||||
|
|
||||||
BaseCrossEncoder --> <name>CrossEncoder # Examples: SagemakerEndpointCrossEncoder
|
BaseCrossEncoder --> <name>CrossEncoder # Examples: SagemakerEndpointCrossEncoder
|
||||||
"""
|
"""
|
||||||
|
import importlib
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
import logging
|
from langchain_community.cross_encoders.base import (
|
||||||
|
BaseCrossEncoder, # noqa: F401
|
||||||
from langchain_community.cross_encoders.base import BaseCrossEncoder
|
)
|
||||||
from langchain_community.cross_encoders.fake import FakeCrossEncoder
|
from langchain_community.cross_encoders.fake import (
|
||||||
from langchain_community.cross_encoders.huggingface import HuggingFaceCrossEncoder
|
FakeCrossEncoder, # noqa: F401
|
||||||
from langchain_community.cross_encoders.sagemaker_endpoint import (
|
)
|
||||||
SagemakerEndpointCrossEncoder,
|
from langchain_community.cross_encoders.huggingface import (
|
||||||
|
HuggingFaceCrossEncoder, # noqa: F401
|
||||||
|
)
|
||||||
|
from langchain_community.cross_encoders.sagemaker_endpoint import (
|
||||||
|
SagemakerEndpointCrossEncoder, # noqa: F401
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseCrossEncoder",
|
"BaseCrossEncoder",
|
||||||
@ -28,3 +32,17 @@ __all__ = [
|
|||||||
"HuggingFaceCrossEncoder",
|
"HuggingFaceCrossEncoder",
|
||||||
"SagemakerEndpointCrossEncoder",
|
"SagemakerEndpointCrossEncoder",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
_module_lookup = {
|
||||||
|
"BaseCrossEncoder": "langchain_community.cross_encoders.base",
|
||||||
|
"FakeCrossEncoder": "langchain_community.cross_encoders.fake",
|
||||||
|
"HuggingFaceCrossEncoder": "langchain_community.cross_encoders.huggingface",
|
||||||
|
"SagemakerEndpointCrossEncoder": "langchain_community.cross_encoders.sagemaker_endpoint", # noqa: E501
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
|
||||||
|
if name in _module_lookup:
|
||||||
|
module = importlib.import_module(_module_lookup[name])
|
||||||
|
return getattr(module, name)
|
||||||
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||||
|
18
libs/community/tests/unit_tests/chat_loaders/test_imports.py
Normal file
18
libs/community/tests/unit_tests/chat_loaders/test_imports.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from langchain_community.chat_loaders import _module_lookup
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"BaseChatLoader",
|
||||||
|
"FolderFacebookMessengerChatLoader",
|
||||||
|
"GMailLoader",
|
||||||
|
"IMessageChatLoader",
|
||||||
|
"LangSmithDatasetChatLoader",
|
||||||
|
"LangSmithRunChatLoader",
|
||||||
|
"SingleFileFacebookMessengerChatLoader",
|
||||||
|
"SlackChatLoader",
|
||||||
|
"TelegramChatLoader",
|
||||||
|
"WhatsAppChatLoader",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(_module_lookup.keys()) == set(EXPECTED_ALL)
|
@ -0,0 +1,29 @@
|
|||||||
|
from langchain_community.chat_message_histories import _module_lookup
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"AstraDBChatMessageHistory",
|
||||||
|
"CassandraChatMessageHistory",
|
||||||
|
"ChatMessageHistory",
|
||||||
|
"CosmosDBChatMessageHistory",
|
||||||
|
"DynamoDBChatMessageHistory",
|
||||||
|
"ElasticsearchChatMessageHistory",
|
||||||
|
"FileChatMessageHistory",
|
||||||
|
"FirestoreChatMessageHistory",
|
||||||
|
"MomentoChatMessageHistory",
|
||||||
|
"MongoDBChatMessageHistory",
|
||||||
|
"Neo4jChatMessageHistory",
|
||||||
|
"PostgresChatMessageHistory",
|
||||||
|
"RedisChatMessageHistory",
|
||||||
|
"RocksetChatMessageHistory",
|
||||||
|
"SQLChatMessageHistory",
|
||||||
|
"SingleStoreDBChatMessageHistory",
|
||||||
|
"StreamlitChatMessageHistory",
|
||||||
|
"TiDBChatMessageHistory",
|
||||||
|
"UpstashRedisChatMessageHistory",
|
||||||
|
"XataChatMessageHistory",
|
||||||
|
"ZepChatMessageHistory",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
assert set(_module_lookup.keys()) == set(EXPECTED_ALL)
|
@ -0,0 +1,14 @@
|
|||||||
|
from langchain_community.cross_encoders import __all__, _module_lookup
|
||||||
|
|
||||||
|
EXPECTED_ALL = [
|
||||||
|
"BaseCrossEncoder",
|
||||||
|
"FakeCrossEncoder",
|
||||||
|
"HuggingFaceCrossEncoder",
|
||||||
|
"SagemakerEndpointCrossEncoder",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_imports() -> None:
|
||||||
|
"""Test that __all__ is correctly set."""
|
||||||
|
assert set(__all__) == set(EXPECTED_ALL)
|
||||||
|
assert set(__all__) == set(_module_lookup.keys())
|
Loading…
Reference in New Issue
Block a user