This commit is contained in:
Bagatur
2024-09-03 17:46:41 -07:00
parent 6aac2eeab5
commit 56163481dd
2 changed files with 45 additions and 1 deletions

View File

@@ -0,0 +1,30 @@
"""Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.embeddings import Embeddings
from langchain_standard_tests.unit_tests.embeddings import EmbeddingsUnitTests
from langchain_fireworks import FireworksEmbeddings
class TestFireworksStandard(EmbeddingsUnitTests):
@property
def embeddings_class(self) -> Type[Embeddings]:
return FireworksEmbeddings
@property
def embeddings_params(self) -> dict:
return {"api_key": "test_api_key"}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return (
{
"FIREWORKS_API_KEY": "api_key",
},
{},
{
"fireworks_api_key": "api_key",
},
)

View File

@@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""
from typing import Type
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found]
@@ -18,3 +18,17 @@ class TestFireworksStandard(ChatModelUnitTests):
@property
def chat_model_params(self) -> dict:
return {"api_key": "test_api_key"}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return (
{
"FIREWORKS_API_KEY": "api_key",
"FIREWORKS_API_BASE": "https://base.com",
},
{},
{
"fireworks_api_key": "api_key",
"fireworks_api_base": "https://base.com",
},
)