From caf4ae3a45628ba748d7e16914c4fb91305fb31f Mon Sep 17 00:00:00 2001 From: Bagatur Date: Wed, 4 Sep 2024 11:28:04 -0700 Subject: [PATCH] fmt --- .../tests/unit_tests/load/test_serializable.py | 10 ++++++++++ libs/core/langchain_core/load/mapping.py | 2 +- .../mistralai/langchain_mistralai/chat_models.py | 4 +++- .../chat_models/__snapshots__/test_azure_standard.ambr | 7 +++++++ 4 files changed, 21 insertions(+), 2 deletions(-) diff --git a/libs/community/tests/unit_tests/load/test_serializable.py b/libs/community/tests/unit_tests/load/test_serializable.py index 7ca4d27da91..d734cc49bd8 100644 --- a/libs/community/tests/unit_tests/load/test_serializable.py +++ b/libs/community/tests/unit_tests/load/test_serializable.py @@ -102,6 +102,16 @@ def test_serializable_mapping() -> None: "modifier", "RemoveMessage", ), + ("langchain", "chat_models", "mistralai", "MistralAI"): ( + "langchain_mistralai", + "chat_models", + "ChatMistralAI", + ), + ("langchain_groq", "chat_models", "ChatGroq"): ( + "langchain_groq", + "chat_models", + "ChatGroq", + ), } serializable_modules = import_all_modules("langchain") diff --git a/libs/core/langchain_core/load/mapping.py b/libs/core/langchain_core/load/mapping.py index c884d65ac55..4aa70dead92 100644 --- a/libs/core/langchain_core/load/mapping.py +++ b/libs/core/langchain_core/load/mapping.py @@ -274,7 +274,7 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { ("langchain_groq", "chat_models", "ChatGroq"): ( "langchain_groq", "chat_models", - "ChatGroq" + "ChatGroq", ), ("langchain", "chat_models", "fireworks", "ChatFireworks"): ( "langchain_fireworks", diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 7c0a75385be..b3a43629e8a 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -362,7 +362,9 @@ class ChatMistralAI(BaseChatModel): """A chat model that uses the MistralAI API.""" client: httpx.Client = Field(default=None, exclude=True) #: :meta private: - async_client: httpx.AsyncClient = Field(default=None, exclude=True) #: :meta private: + async_client: httpx.AsyncClient = Field( + default=None, exclude=True + ) #: :meta private: mistral_api_key: Optional[SecretStr] = Field( alias="api_key", default_factory=secret_from_env("MISTRAL_API_KEY", default=None), diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr index 982c80d816f..8233a605fda 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr +++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr @@ -45,6 +45,13 @@ 'AzureChatOpenAI', ]), 'kwargs': dict({ + 'azure_ad_token': dict({ + 'id': list([ + 'AZURE_OPENAI_AD_TOKEN', + ]), + 'lc': 1, + 'type': 'secret', + }), 'azure_endpoint': 'https://test.azure.com', 'deployment_name': 'test', 'max_retries': 2,