This commit is contained in:
Bagatur
2024-09-04 11:28:04 -07:00
parent fdf6fbde18
commit dba308447d
5 changed files with 26 additions and 3 deletions

View File

@@ -102,6 +102,16 @@ def test_serializable_mapping() -> None:
"modifier", "modifier",
"RemoveMessage", "RemoveMessage",
), ),
("langchain", "chat_models", "mistralai", "MistralAI"): (
"langchain_mistralai",
"chat_models",
"ChatMistralAI",
),
("langchain_groq", "chat_models", "ChatGroq"): (
"langchain_groq",
"chat_models",
"ChatGroq",
),
} }
serializable_modules = import_all_modules("langchain") serializable_modules = import_all_modules("langchain")

View File

@@ -274,7 +274,7 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
("langchain_groq", "chat_models", "ChatGroq"): ( ("langchain_groq", "chat_models", "ChatGroq"): (
"langchain_groq", "langchain_groq",
"chat_models", "chat_models",
"ChatGroq" "ChatGroq",
), ),
("langchain", "chat_models", "fireworks", "ChatFireworks"): ( ("langchain", "chat_models", "fireworks", "ChatFireworks"): (
"langchain_fireworks", "langchain_fireworks",

View File

@@ -1,6 +1,6 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Type from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found] from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found]
@@ -18,3 +18,7 @@ class TestFireworksStandard(ChatModelUnitTests):
@property @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return {"api_key": "test_api_key"} return {"api_key": "test_api_key"}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return ({"FIREWORKS_API_KEY": "api_key"}, {}, {"fireworks_api_key": "api_key"})

View File

@@ -360,7 +360,9 @@ class ChatMistralAI(BaseChatModel):
"""A chat model that uses the MistralAI API.""" """A chat model that uses the MistralAI API."""
client: httpx.Client = Field(default=None, exclude=True) #: :meta private: client: httpx.Client = Field(default=None, exclude=True) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None, exclude=True) #: :meta private: async_client: httpx.AsyncClient = Field(
default=None, exclude=True
) #: :meta private:
mistral_api_key: Optional[SecretStr] = Field( mistral_api_key: Optional[SecretStr] = Field(
alias="api_key", alias="api_key",
default_factory=secret_from_env("MISTRAL_API_KEY", default=None), default_factory=secret_from_env("MISTRAL_API_KEY", default=None),

View File

@@ -45,6 +45,13 @@
'AzureChatOpenAI', 'AzureChatOpenAI',
]), ]),
'kwargs': dict({ 'kwargs': dict({
'azure_ad_token': dict({
'id': list([
'AZURE_OPENAI_AD_TOKEN',
]),
'lc': 1,
'type': 'secret',
}),
'azure_endpoint': 'https://test.azure.com', 'azure_endpoint': 'https://test.azure.com',
'deployment_name': 'test', 'deployment_name': 'test',
'max_retries': 2, 'max_retries': 2,