This commit is contained in:
Bagatur
2024-09-04 11:28:04 -07:00
parent c88b75ca6a
commit caf4ae3a45
4 changed files with 21 additions and 2 deletions

View File

@@ -102,6 +102,16 @@ def test_serializable_mapping() -> None:
"modifier",
"RemoveMessage",
),
("langchain", "chat_models", "mistralai", "MistralAI"): (
"langchain_mistralai",
"chat_models",
"ChatMistralAI",
),
("langchain_groq", "chat_models", "ChatGroq"): (
"langchain_groq",
"chat_models",
"ChatGroq",
),
}
serializable_modules = import_all_modules("langchain")

View File

@@ -274,7 +274,7 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
("langchain_groq", "chat_models", "ChatGroq"): (
"langchain_groq",
"chat_models",
"ChatGroq"
"ChatGroq",
),
("langchain", "chat_models", "fireworks", "ChatFireworks"): (
"langchain_fireworks",

View File

@@ -362,7 +362,9 @@ class ChatMistralAI(BaseChatModel):
"""A chat model that uses the MistralAI API."""
client: httpx.Client = Field(default=None, exclude=True) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None, exclude=True) #: :meta private:
async_client: httpx.AsyncClient = Field(
default=None, exclude=True
) #: :meta private:
mistral_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env("MISTRAL_API_KEY", default=None),

View File

@@ -45,6 +45,13 @@
'AzureChatOpenAI',
]),
'kwargs': dict({
'azure_ad_token': dict({
'id': list([
'AZURE_OPENAI_AD_TOKEN',
]),
'lc': 1,
'type': 'secret',
}),
'azure_endpoint': 'https://test.azure.com',
'deployment_name': 'test',
'max_retries': 2,