standard-tests[patch]: test init from env vars (#25983)

This commit is contained in:
Bagatur 2024-09-03 12:05:39 -07:00 committed by GitHub
parent ac922105ad
commit bc3b02651c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 63 additions and 5 deletions

View File

@ -1,6 +1,6 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Type from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests from langchain_standard_tests.unit_tests import ChatModelUnitTests
@ -12,3 +12,21 @@ class TestOpenAIStandard(ChatModelUnitTests):
@property @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatOpenAI return ChatOpenAI
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return (
{
"OPENAI_API_KEY": "api_key",
"OPENAI_ORGANIZATION": "org_id",
"OPENAI_API_BASE": "api_base",
"OPENAI_PROXY": "https://proxy.com",
},
{},
{
"openai_api_key": "api_key",
"openai_organization": "org_id",
"openai_api_base": "api_base",
"openai_proxy": "https://proxy.com",
},
)

View File

@ -1,11 +1,12 @@
"""Unit tests for chat models.""" """Unit tests for chat models."""
import os
from abc import abstractmethod from abc import abstractmethod
from typing import Any, List, Literal, Optional, Type from typing import Any, List, Literal, Optional, Tuple, Type
from unittest import mock
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.runnables import RunnableBinding from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool from langchain_core.tools import tool
@ -132,12 +133,30 @@ class ChatModelUnitTests(ChatModelTests):
params["api_key"] = "test" params["api_key"] = "test"
return params return params
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {}, {}, {}
def test_init(self) -> None: def test_init(self) -> None:
model = self.chat_model_class( model = self.chat_model_class(
**{**self.standard_chat_model_params, **self.chat_model_params} **{**self.standard_chat_model_params, **self.chat_model_params}
) )
assert model is not None assert model is not None
def test_init_from_env(self) -> None:
env_params, model_params, expected_attrs = self.init_from_env_params
if env_params:
with mock.patch.dict(os.environ, env_params):
model = self.chat_model_class(**model_params)
assert model is not None
for k, expected in expected_attrs.items():
actual = getattr(model, k)
if isinstance(actual, SecretStr):
actual = actual.get_secret_value()
assert actual == expected
def test_init_streaming( def test_init_streaming(
self, self,
) -> None: ) -> None:

View File

@ -1,8 +1,11 @@
import os
from abc import abstractmethod from abc import abstractmethod
from typing import Type from typing import Tuple, Type
from unittest import mock
import pytest import pytest
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import SecretStr
from langchain_standard_tests.base import BaseStandardTests from langchain_standard_tests.base import BaseStandardTests
@ -26,3 +29,21 @@ class EmbeddingsUnitTests(EmbeddingsTests):
def test_init(self) -> None: def test_init(self) -> None:
model = self.embeddings_class(**self.embedding_model_params) model = self.embeddings_class(**self.embedding_model_params)
assert model is not None assert model is not None
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {}, {}, {}
def test_init_from_env(self) -> None:
env_params, embeddings_params, expected_attrs = self.init_from_env_params
if env_params:
with mock.patch.dict(os.environ, env_params):
model = self.embeddings_class(**embeddings_params)
assert model is not None
for k, expected in expected_attrs.items():
actual = getattr(model, k)
if isinstance(actual, SecretStr):
actual = actual.get_secret_value()
assert actual == expected