standard-tests[patch]: test init from env vars (#25983)

This commit is contained in:
Bagatur 2024-09-03 12:05:39 -07:00 committed by GitHub
parent ac922105ad
commit bc3b02651c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 63 additions and 5 deletions

View File

@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""
from typing import Type
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests
@ -12,3 +12,21 @@ class TestOpenAIStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatOpenAI
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return (
{
"OPENAI_API_KEY": "api_key",
"OPENAI_ORGANIZATION": "org_id",
"OPENAI_API_BASE": "api_base",
"OPENAI_PROXY": "https://proxy.com",
},
{},
{
"openai_api_key": "api_key",
"openai_organization": "org_id",
"openai_api_base": "api_base",
"openai_proxy": "https://proxy.com",
},
)

View File

@ -1,11 +1,12 @@
"""Unit tests for chat models."""
import os
from abc import abstractmethod
from typing import Any, List, Literal, Optional, Type
from typing import Any, List, Literal, Optional, Tuple, Type
from unittest import mock
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool
@ -132,12 +133,30 @@ class ChatModelUnitTests(ChatModelTests):
params["api_key"] = "test"
return params
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {}, {}, {}
def test_init(self) -> None:
model = self.chat_model_class(
**{**self.standard_chat_model_params, **self.chat_model_params}
)
assert model is not None
def test_init_from_env(self) -> None:
env_params, model_params, expected_attrs = self.init_from_env_params
if env_params:
with mock.patch.dict(os.environ, env_params):
model = self.chat_model_class(**model_params)
assert model is not None
for k, expected in expected_attrs.items():
actual = getattr(model, k)
if isinstance(actual, SecretStr):
actual = actual.get_secret_value()
assert actual == expected
def test_init_streaming(
self,
) -> None:

View File

@ -1,8 +1,11 @@
import os
from abc import abstractmethod
from typing import Type
from typing import Tuple, Type
from unittest import mock
import pytest
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import SecretStr
from langchain_standard_tests.base import BaseStandardTests
@ -26,3 +29,21 @@ class EmbeddingsUnitTests(EmbeddingsTests):
def test_init(self) -> None:
model = self.embeddings_class(**self.embedding_model_params)
assert model is not None
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {}, {}, {}
def test_init_from_env(self) -> None:
env_params, embeddings_params, expected_attrs = self.init_from_env_params
if env_params:
with mock.patch.dict(os.environ, env_params):
model = self.embeddings_class(**embeddings_params)
assert model is not None
for k, expected in expected_attrs.items():
actual = getattr(model, k)
if isinstance(actual, SecretStr):
actual = actual.get_secret_value()
assert actual == expected