diff --git a/libs/community/tests/unit_tests/load/test_serializable.py b/libs/community/tests/unit_tests/load/test_serializable.py index 7ca4d27da91..d734cc49bd8 100644 --- a/libs/community/tests/unit_tests/load/test_serializable.py +++ b/libs/community/tests/unit_tests/load/test_serializable.py @@ -102,6 +102,16 @@ def test_serializable_mapping() -> None: "modifier", "RemoveMessage", ), + ("langchain", "chat_models", "mistralai", "MistralAI"): ( + "langchain_mistralai", + "chat_models", + "ChatMistralAI", + ), + ("langchain_groq", "chat_models", "ChatGroq"): ( + "langchain_groq", + "chat_models", + "ChatGroq", + ), } serializable_modules = import_all_modules("langchain") diff --git a/libs/core/langchain_core/load/mapping.py b/libs/core/langchain_core/load/mapping.py index c884d65ac55..4aa70dead92 100644 --- a/libs/core/langchain_core/load/mapping.py +++ b/libs/core/langchain_core/load/mapping.py @@ -274,7 +274,7 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { ("langchain_groq", "chat_models", "ChatGroq"): ( "langchain_groq", "chat_models", - "ChatGroq" + "ChatGroq", ), ("langchain", "chat_models", "fireworks", "ChatFireworks"): ( "langchain_fireworks", diff --git a/libs/partners/fireworks/tests/unit_tests/test_standard.py b/libs/partners/fireworks/tests/unit_tests/test_standard.py index 9288aeeb9f8..9d13e19d1ab 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_standard.py +++ b/libs/partners/fireworks/tests/unit_tests/test_standard.py @@ -1,6 +1,6 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Tuple, Type from langchain_core.language_models import BaseChatModel from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found] @@ -18,3 +18,7 @@ class TestFireworksStandard(ChatModelUnitTests): @property def chat_model_params(self) -> dict: return {"api_key": "test_api_key"} + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + return ({"FIREWORKS_API_KEY": "api_key"}, {}, {"fireworks_api_key": "api_key"}) diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 6a497bf91f0..54485ebc55c 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -360,7 +360,9 @@ class ChatMistralAI(BaseChatModel): """A chat model that uses the MistralAI API.""" client: httpx.Client = Field(default=None, exclude=True) #: :meta private: - async_client: httpx.AsyncClient = Field(default=None, exclude=True) #: :meta private: + async_client: httpx.AsyncClient = Field( + default=None, exclude=True + ) #: :meta private: mistral_api_key: Optional[SecretStr] = Field( alias="api_key", default_factory=secret_from_env("MISTRAL_API_KEY", default=None), diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr index 982c80d816f..8233a605fda 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr +++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr @@ -45,6 +45,13 @@ 'AzureChatOpenAI', ]), 'kwargs': dict({ + 'azure_ad_token': dict({ + 'id': list([ + 'AZURE_OPENAI_AD_TOKEN', + ]), + 'lc': 1, + 'type': 'secret', + }), 'azure_endpoint': 'https://test.azure.com', 'deployment_name': 'test', 'max_retries': 2,