standard-tests: rename langchain_standard_tests to langchain_tests, release 0.3.2 (#28203)

This commit is contained in:
Erick Friis
2024-11-18 19:10:39 -08:00
committed by GitHub
parent 24eea2e398
commit 0dbaf05bb7
60 changed files with 70 additions and 83 deletions

View File

@@ -0,0 +1,20 @@
# ruff: noqa: E402
import pytest
# Rewrite assert statements for test suite so that implementations can
# see the full error message from failed asserts.
# https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html#assertion-rewriting
modules = [
"chat_models",
"embeddings",
"tools",
]
for module in modules:
pytest.register_assert_rewrite(f"langchain_tests.unit_tests.{module}")
from .chat_models import ChatModelUnitTests
from .embeddings import EmbeddingsUnitTests
from .tools import ToolsUnitTests
__all__ = ["ChatModelUnitTests", "EmbeddingsUnitTests", "ToolsUnitTests"]

View File

@@ -0,0 +1,269 @@
"""Unit tests for chat models."""
import os
from abc import abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from unittest import mock
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.load import dumpd, load
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool
from pydantic import BaseModel, Field, SecretStr
from pydantic.v1 import (
BaseModel as BaseModelV1,
)
from pydantic.v1 import (
Field as FieldV1,
)
from pydantic.v1 import (
ValidationError as ValidationErrorV1,
)
from syrupy import SnapshotAssertion
from langchain_tests.base import BaseStandardTests
from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
class Person(BaseModel): # Used by some dependent tests. Should be deprecated.
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
def generate_schema_pydantic_v1_from_2() -> Any:
"""Use to generate a schema from v1 namespace in pydantic 2."""
if PYDANTIC_MAJOR_VERSION != 2:
raise AssertionError("This function is only compatible with Pydantic v2.")
class PersonB(BaseModelV1):
"""Record attributes of a person."""
name: str = FieldV1(..., description="The name of the person.")
age: int = FieldV1(..., description="The age of the person.")
return PersonB
def generate_schema_pydantic() -> Any:
"""Works with either pydantic 1 or 2"""
class PersonA(BaseModel):
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
return PersonA
TEST_PYDANTIC_MODELS = [generate_schema_pydantic()]
if PYDANTIC_MAJOR_VERSION == 2:
TEST_PYDANTIC_MODELS.append(generate_schema_pydantic_v1_from_2())
@tool
def my_adder_tool(a: int, b: int) -> int:
"""Takes two integers, a and b, and returns their sum."""
return a + b
def my_adder(a: int, b: int) -> int:
"""Takes two integers, a and b, and returns their sum."""
return a + b
class ChatModelTests(BaseStandardTests):
@property
@abstractmethod
def chat_model_class(self) -> Type[BaseChatModel]: ...
@property
def chat_model_params(self) -> dict:
return {}
@property
def standard_chat_model_params(self) -> dict:
return {
"temperature": 0,
"max_tokens": 100,
"timeout": 60,
"stop": [],
"max_retries": 2,
}
@pytest.fixture
def model(self) -> BaseChatModel:
return self.chat_model_class(
**{**self.standard_chat_model_params, **self.chat_model_params}
)
@property
def has_tool_calling(self) -> bool:
return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return None
@property
def has_structured_output(self) -> bool:
return (
self.chat_model_class.with_structured_output
is not BaseChatModel.with_structured_output
)
@property
def supports_image_inputs(self) -> bool:
return False
@property
def supports_video_inputs(self) -> bool:
return False
@property
def returns_usage_metadata(self) -> bool:
return True
@property
def supports_anthropic_inputs(self) -> bool:
return False
@property
def supports_image_tool_message(self) -> bool:
return False
@property
def supported_usage_metadata_details(
self,
) -> Dict[
Literal["invoke", "stream"],
List[
Literal[
"audio_input",
"audio_output",
"reasoning_output",
"cache_read_input",
"cache_creation_input",
]
],
]:
return {"invoke": [], "stream": []}
class ChatModelUnitTests(ChatModelTests):
@property
def standard_chat_model_params(self) -> dict:
params = super().standard_chat_model_params
params["api_key"] = "test"
return params
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {}, {}, {}
def test_init(self) -> None:
model = self.chat_model_class(
**{**self.standard_chat_model_params, **self.chat_model_params}
)
assert model is not None
def test_init_from_env(self) -> None:
env_params, model_params, expected_attrs = self.init_from_env_params
if env_params:
with mock.patch.dict(os.environ, env_params):
model = self.chat_model_class(**model_params)
assert model is not None
for k, expected in expected_attrs.items():
actual = getattr(model, k)
if isinstance(actual, SecretStr):
actual = actual.get_secret_value()
assert actual == expected
def test_init_streaming(
self,
) -> None:
model = self.chat_model_class(
**{
**self.standard_chat_model_params,
**self.chat_model_params,
"streaming": True,
}
)
assert model is not None
def test_bind_tool_pydantic(
self,
model: BaseChatModel,
) -> None:
if not self.has_tool_calling:
return
tools = [my_adder_tool, my_adder]
for pydantic_model in TEST_PYDANTIC_MODELS:
model_schema = (
pydantic_model.model_json_schema()
if hasattr(pydantic_model, "model_json_schema")
else pydantic_model.schema()
)
tools.extend([pydantic_model, model_schema])
# Doing a mypy ignore here since some of the tools are from pydantic
# BaseModel 2 which isn't typed properly yet. This will need to be fixed
# so type checking does not become annoying to users.
tool_model = model.bind_tools(tools, tool_choice="any") # type: ignore
assert isinstance(tool_model, RunnableBinding)
@pytest.mark.parametrize("schema", TEST_PYDANTIC_MODELS)
def test_with_structured_output(
self,
model: BaseChatModel,
schema: Any,
) -> None:
if not self.has_structured_output:
return
assert model.with_structured_output(schema) is not None
def test_standard_params(self, model: BaseChatModel) -> None:
class ExpectedParams(BaseModelV1):
ls_provider: str
ls_model_name: str
ls_model_type: Literal["chat"]
ls_temperature: Optional[float]
ls_max_tokens: Optional[int]
ls_stop: Optional[List[str]]
ls_params = model._get_ls_params()
try:
ExpectedParams(**ls_params) # type: ignore
except ValidationErrorV1 as e:
pytest.fail(f"Validation error: {e}")
# Test optional params
model = self.chat_model_class(
max_tokens=10,
stop=["test"],
**self.chat_model_params, # type: ignore
)
ls_params = model._get_ls_params()
try:
ExpectedParams(**ls_params) # type: ignore
except ValidationErrorV1 as e:
pytest.fail(f"Validation error: {e}")
def test_serdes(self, model: BaseChatModel, snapshot: SnapshotAssertion) -> None:
if not self.chat_model_class.is_lc_serializable():
return
env_params, model_params, expected_attrs = self.init_from_env_params
with mock.patch.dict(os.environ, env_params):
ser = dumpd(model)
assert ser == snapshot(name="serialized")
assert model.dict() == load(dumpd(model)).dict()

View File

@@ -0,0 +1,48 @@
import os
from abc import abstractmethod
from typing import Tuple, Type
from unittest import mock
import pytest
from langchain_core.embeddings import Embeddings
from pydantic import SecretStr
from langchain_tests.base import BaseStandardTests
class EmbeddingsTests(BaseStandardTests):
@property
@abstractmethod
def embeddings_class(self) -> Type[Embeddings]: ...
@property
def embedding_model_params(self) -> dict:
return {}
@pytest.fixture
def model(self) -> Embeddings:
return self.embeddings_class(**self.embedding_model_params)
class EmbeddingsUnitTests(EmbeddingsTests):
def test_init(self) -> None:
model = self.embeddings_class(**self.embedding_model_params)
assert model is not None
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {}, {}, {}
def test_init_from_env(self) -> None:
env_params, embeddings_params, expected_attrs = self.init_from_env_params
if env_params:
with mock.patch.dict(os.environ, env_params):
model = self.embeddings_class(**embeddings_params)
assert model is not None
for k, expected in expected_attrs.items():
actual = getattr(model, k)
if isinstance(actual, SecretStr):
actual = actual.get_secret_value()
assert actual == expected

View File

@@ -0,0 +1,73 @@
import os
from abc import abstractmethod
from typing import Callable, Tuple, Type, Union
from unittest import mock
import pytest
from langchain_core.tools import BaseTool
from pydantic import SecretStr
from langchain_tests.base import BaseStandardTests
class ToolsTests(BaseStandardTests):
@property
@abstractmethod
def tool_constructor(self) -> Union[Type[BaseTool], Callable]: ...
@property
def tool_constructor_params(self) -> dict:
return {}
@property
def tool_invoke_params_example(self) -> dict:
"""
Returns a dictionary representing the "args" of an example tool call.
This should NOT be a ToolCall dict - i.e. it should not
have {"name", "id", "args"} keys.
"""
return {}
@pytest.fixture
def tool(self) -> BaseTool:
return self.tool_constructor(**self.tool_constructor_params)
class ToolsUnitTests(ToolsTests):
def test_init(self) -> None:
tool = self.tool_constructor(**self.tool_constructor_params)
assert tool is not None
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {}, {}, {}
def test_init_from_env(self) -> None:
env_params, tools_params, expected_attrs = self.init_from_env_params
if env_params:
with mock.patch.dict(os.environ, env_params):
tool = self.tool_constructor(**tools_params)
assert tool is not None
for k, expected in expected_attrs.items():
actual = getattr(tool, k)
if isinstance(actual, SecretStr):
actual = actual.get_secret_value()
assert actual == expected
def test_has_name(self, tool: BaseTool) -> None:
assert tool.name
def test_has_input_schema(self, tool: BaseTool) -> None:
assert tool.get_input_schema()
def test_input_schema_matches_invoke_params(self, tool: BaseTool) -> None:
"""
Tests that the provided example params match the declared input schema
"""
# this will be a pydantic object
input_schema = tool.get_input_schema()
assert input_schema(**self.tool_invoke_params_example)