mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
community[patch]: cross_encoders
flatten namespaces (#20183)
Issue `langchain_community.cross_encoders` didn't have flattening namespace code in the __init__.py file. Changes: - added code to flattening namespaces (used #20050 as a template) - added ut for a change - added missed `test_imports` for `chat_loaders` and `chat_message_histories` modules
This commit is contained in:
parent
1af7133828
commit
2f8dd1a161
@ -1,4 +1,4 @@
|
||||
"""**Cross encoders** are wrappers around cross encoder models from different APIs and
|
||||
"""**Cross encoders** are wrappers around cross encoder models from different APIs and
|
||||
services.
|
||||
|
||||
**Cross encoder models** can be LLMs or not.
|
||||
@ -9,18 +9,22 @@
|
||||
|
||||
BaseCrossEncoder --> <name>CrossEncoder # Examples: SagemakerEndpointCrossEncoder
|
||||
"""
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from langchain_community.cross_encoders.base import BaseCrossEncoder
|
||||
from langchain_community.cross_encoders.fake import FakeCrossEncoder
|
||||
from langchain_community.cross_encoders.huggingface import HuggingFaceCrossEncoder
|
||||
from langchain_community.cross_encoders.sagemaker_endpoint import (
|
||||
SagemakerEndpointCrossEncoder,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.cross_encoders.base import (
|
||||
BaseCrossEncoder, # noqa: F401
|
||||
)
|
||||
from langchain_community.cross_encoders.fake import (
|
||||
FakeCrossEncoder, # noqa: F401
|
||||
)
|
||||
from langchain_community.cross_encoders.huggingface import (
|
||||
HuggingFaceCrossEncoder, # noqa: F401
|
||||
)
|
||||
from langchain_community.cross_encoders.sagemaker_endpoint import (
|
||||
SagemakerEndpointCrossEncoder, # noqa: F401
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseCrossEncoder",
|
||||
@ -28,3 +32,17 @@ __all__ = [
|
||||
"HuggingFaceCrossEncoder",
|
||||
"SagemakerEndpointCrossEncoder",
|
||||
]
|
||||
|
||||
_module_lookup = {
|
||||
"BaseCrossEncoder": "langchain_community.cross_encoders.base",
|
||||
"FakeCrossEncoder": "langchain_community.cross_encoders.fake",
|
||||
"HuggingFaceCrossEncoder": "langchain_community.cross_encoders.huggingface",
|
||||
"SagemakerEndpointCrossEncoder": "langchain_community.cross_encoders.sagemaker_endpoint", # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in _module_lookup:
|
||||
module = importlib.import_module(_module_lookup[name])
|
||||
return getattr(module, name)
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
18
libs/community/tests/unit_tests/chat_loaders/test_imports.py
Normal file
18
libs/community/tests/unit_tests/chat_loaders/test_imports.py
Normal file
@ -0,0 +1,18 @@
|
||||
from langchain_community.chat_loaders import _module_lookup
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"BaseChatLoader",
|
||||
"FolderFacebookMessengerChatLoader",
|
||||
"GMailLoader",
|
||||
"IMessageChatLoader",
|
||||
"LangSmithDatasetChatLoader",
|
||||
"LangSmithRunChatLoader",
|
||||
"SingleFileFacebookMessengerChatLoader",
|
||||
"SlackChatLoader",
|
||||
"TelegramChatLoader",
|
||||
"WhatsAppChatLoader",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(_module_lookup.keys()) == set(EXPECTED_ALL)
|
@ -0,0 +1,29 @@
|
||||
from langchain_community.chat_message_histories import _module_lookup
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"AstraDBChatMessageHistory",
|
||||
"CassandraChatMessageHistory",
|
||||
"ChatMessageHistory",
|
||||
"CosmosDBChatMessageHistory",
|
||||
"DynamoDBChatMessageHistory",
|
||||
"ElasticsearchChatMessageHistory",
|
||||
"FileChatMessageHistory",
|
||||
"FirestoreChatMessageHistory",
|
||||
"MomentoChatMessageHistory",
|
||||
"MongoDBChatMessageHistory",
|
||||
"Neo4jChatMessageHistory",
|
||||
"PostgresChatMessageHistory",
|
||||
"RedisChatMessageHistory",
|
||||
"RocksetChatMessageHistory",
|
||||
"SQLChatMessageHistory",
|
||||
"SingleStoreDBChatMessageHistory",
|
||||
"StreamlitChatMessageHistory",
|
||||
"TiDBChatMessageHistory",
|
||||
"UpstashRedisChatMessageHistory",
|
||||
"XataChatMessageHistory",
|
||||
"ZepChatMessageHistory",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(_module_lookup.keys()) == set(EXPECTED_ALL)
|
@ -0,0 +1,14 @@
|
||||
from langchain_community.cross_encoders import __all__, _module_lookup
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"BaseCrossEncoder",
|
||||
"FakeCrossEncoder",
|
||||
"HuggingFaceCrossEncoder",
|
||||
"SagemakerEndpointCrossEncoder",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
"""Test that __all__ is correctly set."""
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
||||
assert set(__all__) == set(_module_lookup.keys())
|
Loading…
Reference in New Issue
Block a user