diff --git a/libs/partners/fireworks/tests/unit_tests/test_embeddings_standard.py b/libs/partners/fireworks/tests/unit_tests/test_embeddings_standard.py new file mode 100644 index 00000000000..ea8d16f92d0 --- /dev/null +++ b/libs/partners/fireworks/tests/unit_tests/test_embeddings_standard.py @@ -0,0 +1,30 @@ +"""Standard LangChain interface tests""" + +from typing import Tuple, Type + +from langchain_core.embeddings import Embeddings +from langchain_standard_tests.unit_tests.embeddings import EmbeddingsUnitTests + +from langchain_fireworks import FireworksEmbeddings + + +class TestFireworksStandard(EmbeddingsUnitTests): + @property + def embeddings_class(self) -> Type[Embeddings]: + return FireworksEmbeddings + + @property + def embeddings_params(self) -> dict: + return {"api_key": "test_api_key"} + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + return ( + { + "FIREWORKS_API_KEY": "api_key", + }, + {}, + { + "fireworks_api_key": "api_key", + }, + ) diff --git a/libs/partners/fireworks/tests/unit_tests/test_standard.py b/libs/partners/fireworks/tests/unit_tests/test_standard.py index 9288aeeb9f8..61d0d152ba8 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_standard.py +++ b/libs/partners/fireworks/tests/unit_tests/test_standard.py @@ -1,6 +1,6 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Tuple, Type from langchain_core.language_models import BaseChatModel from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found] @@ -18,3 +18,17 @@ class TestFireworksStandard(ChatModelUnitTests): @property def chat_model_params(self) -> dict: return {"api_key": "test_api_key"} + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + return ( + { + "FIREWORKS_API_KEY": "api_key", + "FIREWORKS_API_BASE": "https://base.com", + }, + {}, + { + "fireworks_api_key": "api_key", + "fireworks_api_base": "https://base.com", + }, + )