standard-tests: rename langchain_standard_tests to langchain_tests, release 0.3.2 (#28203)

This commit is contained in:
Erick Friis
2024-11-18 19:10:39 -08:00
committed by GitHub
parent 24eea2e398
commit 0dbaf05bb7
60 changed files with 70 additions and 83 deletions

View File

@@ -0,0 +1,269 @@
"""Unit tests for chat models."""
import os
from abc import abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from unittest import mock
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.load import dumpd, load
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool
from pydantic import BaseModel, Field, SecretStr
from pydantic.v1 import (
BaseModel as BaseModelV1,
)
from pydantic.v1 import (
Field as FieldV1,
)
from pydantic.v1 import (
ValidationError as ValidationErrorV1,
)
from syrupy import SnapshotAssertion
from langchain_tests.base import BaseStandardTests
from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
class Person(BaseModel): # Used by some dependent tests. Should be deprecated.
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
def generate_schema_pydantic_v1_from_2() -> Any:
"""Use to generate a schema from v1 namespace in pydantic 2."""
if PYDANTIC_MAJOR_VERSION != 2:
raise AssertionError("This function is only compatible with Pydantic v2.")
class PersonB(BaseModelV1):
"""Record attributes of a person."""
name: str = FieldV1(..., description="The name of the person.")
age: int = FieldV1(..., description="The age of the person.")
return PersonB
def generate_schema_pydantic() -> Any:
"""Works with either pydantic 1 or 2"""
class PersonA(BaseModel):
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
return PersonA
TEST_PYDANTIC_MODELS = [generate_schema_pydantic()]
if PYDANTIC_MAJOR_VERSION == 2:
TEST_PYDANTIC_MODELS.append(generate_schema_pydantic_v1_from_2())
@tool
def my_adder_tool(a: int, b: int) -> int:
"""Takes two integers, a and b, and returns their sum."""
return a + b
def my_adder(a: int, b: int) -> int:
"""Takes two integers, a and b, and returns their sum."""
return a + b
class ChatModelTests(BaseStandardTests):
@property
@abstractmethod
def chat_model_class(self) -> Type[BaseChatModel]: ...
@property
def chat_model_params(self) -> dict:
return {}
@property
def standard_chat_model_params(self) -> dict:
return {
"temperature": 0,
"max_tokens": 100,
"timeout": 60,
"stop": [],
"max_retries": 2,
}
@pytest.fixture
def model(self) -> BaseChatModel:
return self.chat_model_class(
**{**self.standard_chat_model_params, **self.chat_model_params}
)
@property
def has_tool_calling(self) -> bool:
return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return None
@property
def has_structured_output(self) -> bool:
return (
self.chat_model_class.with_structured_output
is not BaseChatModel.with_structured_output
)
@property
def supports_image_inputs(self) -> bool:
return False
@property
def supports_video_inputs(self) -> bool:
return False
@property
def returns_usage_metadata(self) -> bool:
return True
@property
def supports_anthropic_inputs(self) -> bool:
return False
@property
def supports_image_tool_message(self) -> bool:
return False
@property
def supported_usage_metadata_details(
self,
) -> Dict[
Literal["invoke", "stream"],
List[
Literal[
"audio_input",
"audio_output",
"reasoning_output",
"cache_read_input",
"cache_creation_input",
]
],
]:
return {"invoke": [], "stream": []}
class ChatModelUnitTests(ChatModelTests):
@property
def standard_chat_model_params(self) -> dict:
params = super().standard_chat_model_params
params["api_key"] = "test"
return params
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {}, {}, {}
def test_init(self) -> None:
model = self.chat_model_class(
**{**self.standard_chat_model_params, **self.chat_model_params}
)
assert model is not None
def test_init_from_env(self) -> None:
env_params, model_params, expected_attrs = self.init_from_env_params
if env_params:
with mock.patch.dict(os.environ, env_params):
model = self.chat_model_class(**model_params)
assert model is not None
for k, expected in expected_attrs.items():
actual = getattr(model, k)
if isinstance(actual, SecretStr):
actual = actual.get_secret_value()
assert actual == expected
def test_init_streaming(
self,
) -> None:
model = self.chat_model_class(
**{
**self.standard_chat_model_params,
**self.chat_model_params,
"streaming": True,
}
)
assert model is not None
def test_bind_tool_pydantic(
self,
model: BaseChatModel,
) -> None:
if not self.has_tool_calling:
return
tools = [my_adder_tool, my_adder]
for pydantic_model in TEST_PYDANTIC_MODELS:
model_schema = (
pydantic_model.model_json_schema()
if hasattr(pydantic_model, "model_json_schema")
else pydantic_model.schema()
)
tools.extend([pydantic_model, model_schema])
# Doing a mypy ignore here since some of the tools are from pydantic
# BaseModel 2 which isn't typed properly yet. This will need to be fixed
# so type checking does not become annoying to users.
tool_model = model.bind_tools(tools, tool_choice="any") # type: ignore
assert isinstance(tool_model, RunnableBinding)
@pytest.mark.parametrize("schema", TEST_PYDANTIC_MODELS)
def test_with_structured_output(
self,
model: BaseChatModel,
schema: Any,
) -> None:
if not self.has_structured_output:
return
assert model.with_structured_output(schema) is not None
def test_standard_params(self, model: BaseChatModel) -> None:
class ExpectedParams(BaseModelV1):
ls_provider: str
ls_model_name: str
ls_model_type: Literal["chat"]
ls_temperature: Optional[float]
ls_max_tokens: Optional[int]
ls_stop: Optional[List[str]]
ls_params = model._get_ls_params()
try:
ExpectedParams(**ls_params) # type: ignore
except ValidationErrorV1 as e:
pytest.fail(f"Validation error: {e}")
# Test optional params
model = self.chat_model_class(
max_tokens=10,
stop=["test"],
**self.chat_model_params, # type: ignore
)
ls_params = model._get_ls_params()
try:
ExpectedParams(**ls_params) # type: ignore
except ValidationErrorV1 as e:
pytest.fail(f"Validation error: {e}")
def test_serdes(self, model: BaseChatModel, snapshot: SnapshotAssertion) -> None:
if not self.chat_model_class.is_lc_serializable():
return
env_params, model_params, expected_attrs = self.init_from_env_params
with mock.patch.dict(os.environ, env_params):
ser = dumpd(model)
assert ser == snapshot(name="serialized")
assert model.dict() == load(dumpd(model)).dict()