mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 13:59:49 +00:00
fmt
This commit is contained in:
@@ -16,10 +16,10 @@ from langchain_core.messages import (
|
|||||||
)
|
)
|
||||||
from langchain_core.output_parsers import StrOutputParser
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from pydantic import BaseModel as RawBaseModel
|
from pydantic import BaseModel, Field
|
||||||
from pydantic import Field as RawField
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
from pydantic.v1 import Field as FieldV1
|
||||||
|
|
||||||
from langchain_standard_tests.unit_tests.chat_models import (
|
from langchain_standard_tests.unit_tests.chat_models import (
|
||||||
ChatModelTests,
|
ChatModelTests,
|
||||||
@@ -28,8 +28,8 @@ from langchain_standard_tests.unit_tests.chat_models import (
|
|||||||
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||||
|
|
||||||
|
|
||||||
class MagicFunctionSchema(RawBaseModel):
|
class MagicFunctionSchema(BaseModel):
|
||||||
input: int = RawField(..., gt=-1000, lt=1000)
|
input: int = Field(..., gt=-1000, lt=1000)
|
||||||
|
|
||||||
|
|
||||||
@tool(args_schema=MagicFunctionSchema)
|
@tool(args_schema=MagicFunctionSchema)
|
||||||
@@ -240,14 +240,11 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
if not self.has_tool_calling:
|
if not self.has_tool_calling:
|
||||||
pytest.skip("Test requires tool calling.")
|
pytest.skip("Test requires tool calling.")
|
||||||
|
|
||||||
from pydantic import BaseModel as BaseModelProper
|
class Joke(BaseModel):
|
||||||
from pydantic import Field as FieldProper
|
|
||||||
|
|
||||||
class Joke(BaseModelProper):
|
|
||||||
"""Joke to tell user."""
|
"""Joke to tell user."""
|
||||||
|
|
||||||
setup: str = FieldProper(description="question to set up a joke")
|
setup: str = Field(description="question to set up a joke")
|
||||||
punchline: str = FieldProper(description="answer to resolve the joke")
|
punchline: str = Field(description="answer to resolve the joke")
|
||||||
|
|
||||||
# Pydantic class
|
# Pydantic class
|
||||||
# Type ignoring since the interface only officially supports pydantic 1
|
# Type ignoring since the interface only officially supports pydantic 1
|
||||||
@@ -280,11 +277,11 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
if not self.has_tool_calling:
|
if not self.has_tool_calling:
|
||||||
pytest.skip("Test requires tool calling.")
|
pytest.skip("Test requires tool calling.")
|
||||||
|
|
||||||
class Joke(BaseModel): # Uses langchain_core.pydantic_v1.BaseModel
|
class Joke(BaseModelV1): # Uses langchain_core.pydantic_v1.BaseModel
|
||||||
"""Joke to tell user."""
|
"""Joke to tell user."""
|
||||||
|
|
||||||
setup: str = Field(description="question to set up a joke")
|
setup: str = FieldV1(description="question to set up a joke")
|
||||||
punchline: str = Field(description="answer to resolve the joke")
|
punchline: str = FieldV1(description="answer to resolve the joke")
|
||||||
|
|
||||||
# Pydantic class
|
# Pydantic class
|
||||||
chat = model.with_structured_output(Joke)
|
chat = model.with_structured_output(Joke)
|
||||||
@@ -439,7 +436,7 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
if not self.supports_anthropic_inputs:
|
if not self.supports_anthropic_inputs:
|
||||||
return
|
return
|
||||||
|
|
||||||
class color_picker(BaseModel):
|
class color_picker(BaseModelV1):
|
||||||
"""Input your fav color and get a random fact about it."""
|
"""Input your fav color and get a random fact about it."""
|
||||||
|
|
||||||
fav_color: str
|
fav_color: str
|
||||||
|
@@ -9,6 +9,15 @@ from langchain_core.language_models import BaseChatModel
|
|||||||
from langchain_core.runnables import RunnableBinding
|
from langchain_core.runnables import RunnableBinding
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
from pydantic.v1 import (
|
||||||
|
BaseModel as BaseModelV1,
|
||||||
|
)
|
||||||
|
from pydantic.v1 import (
|
||||||
|
Field as FieldV1,
|
||||||
|
)
|
||||||
|
from pydantic.v1 import (
|
||||||
|
ValidationError as ValidationErrorV1,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain_standard_tests.base import BaseStandardTests
|
from langchain_standard_tests.base import BaseStandardTests
|
||||||
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||||
@@ -25,27 +34,24 @@ def generate_schema_pydantic_v1_from_2() -> Any:
|
|||||||
"""Use to generate a schema from v1 namespace in pydantic 2."""
|
"""Use to generate a schema from v1 namespace in pydantic 2."""
|
||||||
if PYDANTIC_MAJOR_VERSION != 2:
|
if PYDANTIC_MAJOR_VERSION != 2:
|
||||||
raise AssertionError("This function is only compatible with Pydantic v2.")
|
raise AssertionError("This function is only compatible with Pydantic v2.")
|
||||||
from pydantic.v1 import BaseModel, Field
|
|
||||||
|
|
||||||
class PersonB(BaseModel):
|
class PersonB(BaseModelV1):
|
||||||
"""Record attributes of a person."""
|
"""Record attributes of a person."""
|
||||||
|
|
||||||
name: str = Field(..., description="The name of the person.")
|
name: str = FieldV1(..., description="The name of the person.")
|
||||||
age: int = Field(..., description="The age of the person.")
|
age: int = FieldV1(..., description="The age of the person.")
|
||||||
|
|
||||||
return PersonB
|
return PersonB
|
||||||
|
|
||||||
|
|
||||||
def generate_schema_pydantic() -> Any:
|
def generate_schema_pydantic() -> Any:
|
||||||
"""Works with either pydantic 1 or 2"""
|
"""Works with either pydantic 1 or 2"""
|
||||||
from pydantic import BaseModel as BaseModelProper
|
|
||||||
from pydantic import Field as FieldProper
|
|
||||||
|
|
||||||
class PersonA(BaseModelProper):
|
class PersonA(BaseModel):
|
||||||
"""Record attributes of a person."""
|
"""Record attributes of a person."""
|
||||||
|
|
||||||
name: str = FieldProper(..., description="The name of the person.")
|
name: str = Field(..., description="The name of the person.")
|
||||||
age: int = FieldProper(..., description="The age of the person.")
|
age: int = Field(..., description="The age of the person.")
|
||||||
|
|
||||||
return PersonA
|
return PersonA
|
||||||
|
|
||||||
@@ -199,9 +205,7 @@ class ChatModelUnitTests(ChatModelTests):
|
|||||||
assert model.with_structured_output(schema) is not None
|
assert model.with_structured_output(schema) is not None
|
||||||
|
|
||||||
def test_standard_params(self, model: BaseChatModel) -> None:
|
def test_standard_params(self, model: BaseChatModel) -> None:
|
||||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
class ExpectedParams(BaseModelV1):
|
||||||
|
|
||||||
class ExpectedParams(BaseModel):
|
|
||||||
ls_provider: str
|
ls_provider: str
|
||||||
ls_model_name: str
|
ls_model_name: str
|
||||||
ls_model_type: Literal["chat"]
|
ls_model_type: Literal["chat"]
|
||||||
@@ -212,7 +216,7 @@ class ChatModelUnitTests(ChatModelTests):
|
|||||||
ls_params = model._get_ls_params()
|
ls_params = model._get_ls_params()
|
||||||
try:
|
try:
|
||||||
ExpectedParams(**ls_params)
|
ExpectedParams(**ls_params)
|
||||||
except ValidationError as e:
|
except ValidationErrorV1 as e:
|
||||||
pytest.fail(f"Validation error: {e}")
|
pytest.fail(f"Validation error: {e}")
|
||||||
|
|
||||||
# Test optional params
|
# Test optional params
|
||||||
@@ -222,5 +226,5 @@ class ChatModelUnitTests(ChatModelTests):
|
|||||||
ls_params = model._get_ls_params()
|
ls_params = model._get_ls_params()
|
||||||
try:
|
try:
|
||||||
ExpectedParams(**ls_params)
|
ExpectedParams(**ls_params)
|
||||||
except ValidationError as e:
|
except ValidationErrorV1 as e:
|
||||||
pytest.fail(f"Validation error: {e}")
|
pytest.fail(f"Validation error: {e}")
|
||||||
|
Reference in New Issue
Block a user