Compare commits

...

4 Commits

Author SHA1 Message Date
Eugene Yurtsev
49d12d5e7e Merge branch 'master' into eugene/add_tests_for_pydantic_models 2024-07-22 10:29:41 -04:00
Eugene Yurtsev
59a7c048d3 x 2024-07-18 12:49:02 -04:00
Eugene Yurtsev
61eb096a5f Update 2024-07-18 12:41:55 -04:00
Eugene Yurtsev
f53528db71 x 2024-07-18 12:28:22 -04:00
3 changed files with 62 additions and 6 deletions

View File

@@ -1,20 +1,57 @@
"""Unit tests for chat models."""
from abc import ABC, abstractmethod
from typing import Any, List, Literal, Optional, Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
class Person(BaseModel):
class Person(BaseModel): # Used by some dependent tests. Should be deprecated.
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
def generate_schema_pydantic_v1_from_2() -> Any:
"""Use to generate a schema from v1 namespace in pydantic 2."""
if PYDANTIC_MAJOR_VERSION != 2:
raise AssertionError("This function is only compatible with Pydantic v2.")
from pydantic.v1 import BaseModel, Field
class PersonB(BaseModel):
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
return PersonB
def generate_schema_pydantic() -> Any:
"""Works with either pydantic 1 or 2"""
from pydantic import BaseModel, Field
class PersonA(BaseModel):
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
return PersonA
TEST_PYDANTIC_MODELS = [generate_schema_pydantic()]
if PYDANTIC_MAJOR_VERSION == 2:
TEST_PYDANTIC_MODELS.append(generate_schema_pydantic_v1_from_2())
@tool
def my_adder_tool(a: int, b: int) -> int:
"""Takes two integers, a and b, and returns their sum."""
@@ -112,12 +149,15 @@ class ChatModelUnitTests(ChatModelTests):
if not self.has_tool_calling:
return
tool_model = model.bind_tools(
[Person, Person.schema(), my_adder_tool, my_adder], tool_choice="any"
)
tools = [my_adder_tool, my_adder]
for pydantic_model in TEST_PYDANTIC_MODELS:
tools.extend([pydantic_model, pydantic_model.schema()])
tool_model = model.bind_tools(tools, tool_choice="any")
assert isinstance(tool_model, RunnableBinding)
@pytest.mark.parametrize("schema", [Person, Person.schema()])
@pytest.mark.parametrize("schema", TEST_PYDANTIC_MODELS)
def test_with_structured_output(
self,
model: BaseChatModel,
@@ -129,6 +169,8 @@ class ChatModelUnitTests(ChatModelTests):
assert model.with_structured_output(schema) is not None
def test_standard_params(self, model: BaseChatModel) -> None:
from langchain_core.pydantic_v1 import BaseModel, ValidationError
class ExpectedParams(BaseModel):
ls_provider: str
ls_model_name: str

View File

@@ -0,0 +1,14 @@
"""Utilities for working with pydantic models."""
def get_pydantic_major_version() -> int:
"""Get the major version of Pydantic."""
try:
import pydantic
return int(pydantic.__version__.split(".")[0])
except ImportError:
return 0
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()