standard-tests[patch]: Update chat model standard tests (#22378)

- Refactor standard test classes to make them easier to configure
- Update openai to support stop_sequences init param
- Update groq to support stop_sequences init param
- Update fireworks to support max_retries init param
- Update ChatModel.bind_tools to type tool_choice
- Update groq to handle tool_choice="any". **this may be controversial**

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Bagatur 2024-06-17 13:37:41 -07:00 committed by GitHub
parent 14f0cdad58
commit d96f67b06f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 383 additions and 378 deletions

View File

@ -10,98 +10,38 @@ from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain_ai21 import ChatAI21 from langchain_ai21 import ChatAI21
class TestAI21J2(ChatModelIntegrationTests): class BaseTestAI21(ChatModelIntegrationTests):
def teardown(self) -> None: def teardown(self) -> None:
# avoid getting rate limited # avoid getting rate limited
time.sleep(1) time.sleep(1)
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21 return ChatAI21
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.") @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
def test_stream( def test_stream(self, model: BaseChatModel) -> None:
self, super().test_stream(model)
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_stream(
chat_model_class,
chat_model_params,
)
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.") @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
async def test_astream( async def test_astream(self, model: BaseChatModel) -> None:
self, await super().test_astream(model)
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
await super().test_astream(
chat_model_class,
chat_model_params,
)
@pytest.mark.xfail(reason="Not implemented.") @pytest.mark.xfail(reason="Not implemented.")
def test_usage_metadata( def test_usage_metadata(self, model: BaseChatModel) -> None:
self, super().test_usage_metadata(model)
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_usage_metadata(
chat_model_class,
chat_model_params,
)
@pytest.fixture
class TestAI21J2(BaseTestAI21):
@property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {
"model": "j2-ultra", "model": "j2-ultra",
} }
class TestAI21Jamba(ChatModelIntegrationTests): class TestAI21Jamba(BaseTestAI21):
def teardown(self) -> None: @property
# avoid getting rate limited
time.sleep(1)
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
def test_stream(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_stream(
chat_model_class,
chat_model_params,
)
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
async def test_astream(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
await super().test_astream(
chat_model_class,
chat_model_params,
)
@pytest.mark.xfail(reason="Not implemented.")
def test_usage_metadata(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_usage_metadata(
chat_model_class,
chat_model_params,
)
@pytest.fixture
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {
"model": "jamba-instruct-preview", "model": "jamba-instruct-preview",

View File

@ -2,7 +2,6 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests from langchain_standard_tests.unit_tests import ChatModelUnitTests
@ -10,11 +9,11 @@ from langchain_ai21 import ChatAI21
class TestAI21J2(ChatModelUnitTests): class TestAI21J2(ChatModelUnitTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21 return ChatAI21
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {
"model": "j2-ultra", "model": "j2-ultra",
@ -23,11 +22,11 @@ class TestAI21J2(ChatModelUnitTests):
class TestAI21Jamba(ChatModelUnitTests): class TestAI21Jamba(ChatModelUnitTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21 return ChatAI21
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {
"model": "jamba-instruct", "model": "jamba-instruct",

View File

@ -2,7 +2,6 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
@ -10,12 +9,10 @@ from langchain_anthropic import ChatAnthropic
class TestAnthropicStandard(ChatModelIntegrationTests): class TestAnthropicStandard(ChatModelIntegrationTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAnthropic return ChatAnthropic
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {"model": "claude-3-haiku-20240307"}
"model": "claude-3-haiku-20240307",
}

View File

@ -2,7 +2,6 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests from langchain_standard_tests.unit_tests import ChatModelUnitTests
@ -10,12 +9,10 @@ from langchain_anthropic import ChatAnthropic
class TestAnthropicStandard(ChatModelUnitTests): class TestAnthropicStandard(ChatModelUnitTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAnthropic return ChatAnthropic
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {"model": "claude-3-haiku-20240307"}
"model": "claude-3-haiku-20240307",
}

View File

@ -296,6 +296,8 @@ class ChatFireworks(BaseChatModel):
"""Model name to use.""" """Model name to use."""
temperature: float = 0.0 temperature: float = 0.0
"""What sampling temperature to use.""" """What sampling temperature to use."""
stop: Optional[Union[str, List[str]]] = Field(None, alias="stop_sequences")
"""Default stop sequences."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
fireworks_api_key: SecretStr = Field(default=None, alias="api_key") fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
@ -314,8 +316,8 @@ class ChatFireworks(BaseChatModel):
"""Number of chat completions to generate for each prompt.""" """Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
"""Maximum number of tokens to generate.""" """Maximum number of tokens to generate."""
stop: Optional[List[str]] = Field(None, alias="stop_sequences") max_retries: Optional[int] = None
"""Default stop sequences.""" """Maximum number of retries to make when generating."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -360,6 +362,9 @@ class ChatFireworks(BaseChatModel):
values["client"] = Fireworks(**client_params).chat.completions values["client"] = Fireworks(**client_params).chat.completions
if not values.get("async_client"): if not values.get("async_client"):
values["async_client"] = AsyncFireworks(**client_params).chat.completions values["async_client"] = AsyncFireworks(**client_params).chat.completions
if values["max_retries"]:
values["client"]._max_retries = values["max_retries"]
values["async_client"]._max_retries = values["max_retries"]
return values return values
@property @property

View File

@ -10,11 +10,11 @@ from langchain_fireworks import ChatFireworks
class TestFireworksStandard(ChatModelIntegrationTests): class TestFireworksStandard(ChatModelIntegrationTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatFireworks return ChatFireworks
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {
"model": "accounts/fireworks/models/firefunction-v1", "model": "accounts/fireworks/models/firefunction-v1",
@ -22,12 +22,5 @@ class TestFireworksStandard(ChatModelIntegrationTests):
} }
@pytest.mark.xfail(reason="Not yet implemented.") @pytest.mark.xfail(reason="Not yet implemented.")
def test_tool_message_histories_list_content( def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None:
self, super().test_tool_message_histories_list_content(model)
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None:
super().test_tool_message_histories_list_content(
chat_model_class, chat_model_params, chat_model_has_tool_calling
)

View File

@ -2,7 +2,6 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests from langchain_standard_tests.unit_tests import ChatModelUnitTests
@ -10,12 +9,10 @@ from langchain_fireworks import ChatFireworks
class TestFireworksStandard(ChatModelUnitTests): class TestFireworksStandard(ChatModelUnitTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatFireworks return ChatFireworks
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {"api_key": "test_api_key"}
"api_key": "test_api_key",
}

View File

@ -304,6 +304,8 @@ class ChatGroq(BaseChatModel):
"""Model name to use.""" """Model name to use."""
temperature: float = 0.7 temperature: float = 0.7
"""What sampling temperature to use.""" """What sampling temperature to use."""
stop: Optional[Union[List[str], str]] = Field(None, alias="stop_sequences")
"""Default stop sequences."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
groq_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") groq_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
@ -326,8 +328,6 @@ class ChatGroq(BaseChatModel):
"""Number of chat completions to generate for each prompt.""" """Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
"""Maximum number of tokens to generate.""" """Maximum number of tokens to generate."""
stop: Optional[List[str]] = Field(None, alias="stop_sequences")
"""Default stop sequences."""
default_headers: Union[Mapping[str, str], None] = None default_headers: Union[Mapping[str, str], None] = None
default_query: Union[Mapping[str, object], None] = None default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the # Configure a custom httpx client. See the
@ -449,7 +449,7 @@ class ChatGroq(BaseChatModel):
if ls_max_tokens := params.get("max_tokens", self.max_tokens): if ls_max_tokens := params.get("max_tokens", self.max_tokens):
ls_params["ls_max_tokens"] = ls_max_tokens ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None) or self.stop: if ls_stop := stop or params.get("stop", None) or self.stop:
ls_params["ls_stop"] = ls_stop ls_params["ls_stop"] = ls_stop if isinstance(ls_stop, list) else [ls_stop]
return ls_params return ls_params
def _generate( def _generate(
@ -804,10 +804,19 @@ class ChatGroq(BaseChatModel):
formatted_tools = [convert_to_openai_tool(tool) for tool in tools] formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice is not None and tool_choice: if tool_choice is not None and tool_choice:
if tool_choice == "any":
if len(tools) > 1:
raise ValueError(
f"Groq does not currently support {tool_choice=}. Should "
f"be one of 'auto', 'none', or the name of the tool to call."
)
else:
tool_choice = convert_to_openai_tool(tools[0])["function"]["name"]
if isinstance(tool_choice, str) and ( if isinstance(tool_choice, str) and (
tool_choice not in ("auto", "any", "none") tool_choice not in ("auto", "any", "none")
): ):
tool_choice = {"type": "function", "function": {"name": tool_choice}} tool_choice = {"type": "function", "function": {"name": tool_choice}}
# TODO: Remove this update once 'any' is supported.
if isinstance(tool_choice, dict) and (len(formatted_tools) != 1): if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
raise ValueError( raise ValueError(
"When specifying `tool_choice`, you must provide exactly one " "When specifying `tool_choice`, you must provide exactly one "

View File

@ -92,5 +92,6 @@ filterwarnings = [
'ignore:The method `ChatGroq.with_structured_output` is in beta', 'ignore:The method `ChatGroq.with_structured_output` is in beta',
# Maintain support for pydantic 1.X # Maintain support for pydantic 1.X
'default:The `dict` method is deprecated; use `model_dump` instead:DeprecationWarning', 'default:The `dict` method is deprecated; use `model_dump` instead:DeprecationWarning',
"ignore:tool_choice='any' is not currently supported. Converting to 'auto'.",
] ]
asyncio_mode = "auto" asyncio_mode = "auto"

View File

@ -10,17 +10,10 @@ from langchain_groq import ChatGroq
class TestGroqStandard(ChatModelIntegrationTests): class TestGroqStandard(ChatModelIntegrationTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatGroq return ChatGroq
@pytest.mark.xfail(reason="Not yet implemented.") @pytest.mark.xfail(reason="Not yet implemented.")
def test_tool_message_histories_list_content( def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None:
self, super().test_tool_message_histories_list_content(model)
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None:
super().test_tool_message_histories_list_content(
chat_model_class, chat_model_params, chat_model_has_tool_calling
)

View File

@ -2,14 +2,26 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests from langchain_core.runnables import RunnableBinding
from langchain_standard_tests.unit_tests.chat_models import (
ChatModelUnitTests,
Person,
my_adder_tool,
)
from langchain_groq import ChatGroq from langchain_groq import ChatGroq
class TestGroqStandard(ChatModelUnitTests): class TestGroqStandard(ChatModelUnitTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatGroq return ChatGroq
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
"""Does not currently support tool_choice='any'."""
if not self.has_tool_calling:
return
tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool])
assert isinstance(tool_model, RunnableBinding)

View File

@ -2,7 +2,6 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
@ -10,13 +9,10 @@ from langchain_mistralai import ChatMistralAI
class TestMistralStandard(ChatModelIntegrationTests): class TestMistralStandard(ChatModelIntegrationTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatMistralAI return ChatMistralAI
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {"model": "mistral-large-latest", "temperature": 0}
"model": "mistral-large-latest",
"temperature": 0,
}

View File

@ -2,7 +2,6 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests from langchain_standard_tests.unit_tests import ChatModelUnitTests
@ -10,6 +9,6 @@ from langchain_mistralai import ChatMistralAI
class TestMistralStandard(ChatModelUnitTests): class TestMistralStandard(ChatModelUnitTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatMistralAI return ChatMistralAI

View File

@ -3,13 +3,18 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Type, Union
import openai import openai
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import LangSmithParams from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult from langchain_core.outputs import ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai.chat_models.base import BaseChatOpenAI from langchain_openai.chat_models.base import BaseChatOpenAI
@ -210,6 +215,27 @@ class AzureChatOpenAI(BaseChatOpenAI):
).chat.completions ).chat.completions
return values return values
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
# As of 05/2024 Azure OpenAI doesn't support tool_choice="required".
# TODO: Update this condition once tool_choice="required" is supported.
if tool_choice in ("any", "required", True):
if len(tools) > 1:
raise ValueError(
f"Azure OpenAI does not currently support {tool_choice=}. Should "
f"be one of 'auto', 'none', or the name of the tool to call."
)
else:
tool_choice = convert_to_openai_tool(tools[0])["function"]["name"]
return super().bind_tools(tools, tool_choice=tool_choice, **kwargs)
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""

View File

@ -345,6 +345,8 @@ class BaseChatOpenAI(BaseChatModel):
http_async_client: Union[Any, None] = None http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify """Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations.""" http_client as well if you'd like a custom client for sync invocations."""
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -441,6 +443,7 @@ class BaseChatOpenAI(BaseChatModel):
"stream": self.streaming, "stream": self.streaming,
"n": self.n, "n": self.n,
"temperature": self.temperature, "temperature": self.temperature,
"stop": self.stop,
**self.model_kwargs, **self.model_kwargs,
} }
if self.max_tokens is not None: if self.max_tokens is not None:
@ -548,8 +551,6 @@ class BaseChatOpenAI(BaseChatModel):
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._default_params params = self._default_params
if stop is not None: if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages] message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params return message_dicts, params
@ -871,15 +872,7 @@ class BaseChatOpenAI(BaseChatModel):
if tool_choice == "any": if tool_choice == "any":
tool_choice = "required" tool_choice = "required"
elif isinstance(tool_choice, bool): elif isinstance(tool_choice, bool):
if len(tools) > 1: tool_choice = "required"
raise ValueError(
"tool_choice=True can only be used when a single tool is "
f"passed in, received {len(tools)} tools."
)
tool_choice = {
"type": "function",
"function": {"name": formatted_tools[0]["function"]["name"]},
}
elif isinstance(tool_choice, dict): elif isinstance(tool_choice, dict):
tool_names = [ tool_names = [
formatted_tool["function"]["name"] formatted_tool["function"]["name"]
@ -1094,7 +1087,7 @@ class BaseChatOpenAI(BaseChatModel):
"schema must be specified when method is 'function_calling'. " "schema must be specified when method is 'function_calling'. "
"Received None." "Received None."
) )
llm = self.bind_tools([schema], tool_choice=True) llm = self.bind_tools([schema], tool_choice="any")
if is_pydantic_schema: if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True tools=[schema], first_tool_only=True

View File

@ -440,8 +440,6 @@ class BaseOpenAI(BaseLLM):
) -> List[List[str]]: ) -> List[List[str]]:
"""Get the sub prompts for llm call.""" """Get the sub prompts for llm call."""
if stop is not None: if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop params["stop"] = stop
if params["max_tokens"] == -1: if params["max_tokens"] == -1:
if len(prompts) != 1: if len(prompts) != 1:

View File

@ -3,7 +3,6 @@
import os import os
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
@ -19,15 +18,15 @@ DEPLOYMENT_NAME = os.environ.get(
class TestOpenAIStandard(ChatModelIntegrationTests): class TestOpenAIStandard(ChatModelIntegrationTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return AzureChatOpenAI return AzureChatOpenAI
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {
"deployment_name": DEPLOYMENT_NAME, "deployment_name": DEPLOYMENT_NAME,
"openai_api_version": OPENAI_API_VERSION, "openai_api_version": OPENAI_API_VERSION,
"azure_endpoint": OPENAI_API_BASE, "azure_endpoint": OPENAI_API_BASE,
"openai_api_key": OPENAI_API_KEY, "api_key": OPENAI_API_KEY,
} }

View File

@ -2,7 +2,6 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
@ -10,6 +9,6 @@ from langchain_openai import ChatOpenAI
class TestOpenAIStandard(ChatModelIntegrationTests): class TestOpenAIStandard(ChatModelIntegrationTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatOpenAI return ChatOpenAI

View File

@ -99,13 +99,6 @@ def test_openai_stop_valid() -> None:
assert first_output == second_output assert first_output == second_output
def test_openai_stop_error() -> None:
"""Test openai stop logic on bad configuration."""
llm = OpenAI(stop="3", temperature=0)
with pytest.raises(ValueError):
llm.invoke("write an ordered list of five items", stop=["\n"])
@pytest.mark.scheduled @pytest.mark.scheduled
def test_openai_streaming() -> None: def test_openai_streaming() -> None:
"""Test streaming tokens from OpenAI.""" """Test streaming tokens from OpenAI."""

View File

@ -2,23 +2,31 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import RunnableBinding
from langchain_standard_tests.unit_tests import ChatModelUnitTests from langchain_standard_tests.unit_tests import ChatModelUnitTests
from langchain_standard_tests.unit_tests.chat_models import Person, my_adder_tool
from langchain_openai import AzureChatOpenAI from langchain_openai import AzureChatOpenAI
class TestOpenAIStandard(ChatModelUnitTests): class TestOpenAIStandard(ChatModelUnitTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return AzureChatOpenAI return AzureChatOpenAI
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {
"deployment_name": "test", "deployment_name": "test",
"openai_api_version": "2021-10-01", "openai_api_version": "2021-10-01",
"azure_endpoint": "https://test.azure.com", "azure_endpoint": "https://test.azure.com",
"openai_api_key": "test",
} }
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
"""Does not currently support tool_choice='any'."""
if not self.has_tool_calling:
return
tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool])
assert isinstance(tool_model, RunnableBinding)

View File

@ -2,7 +2,6 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests from langchain_standard_tests.unit_tests import ChatModelUnitTests
@ -10,6 +9,6 @@ from langchain_openai import ChatOpenAI
class TestOpenAIStandard(ChatModelUnitTests): class TestOpenAIStandard(ChatModelUnitTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatOpenAI return ChatOpenAI

View File

@ -566,7 +566,7 @@ files = [
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "0.2.4" version = "0.2.7"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@ -575,15 +575,12 @@ develop = true
[package.dependencies] [package.dependencies]
jsonpatch = "^1.33" jsonpatch = "^1.33"
langsmith = "^0.1.66" langsmith = "^0.1.75"
packaging = "^23.2" packaging = ">=23.2,<25"
pydantic = ">=1,<3" pydantic = ">=1,<3"
PyYAML = ">=5.3" PyYAML = ">=5.3"
tenacity = "^8.1.0" tenacity = "^8.1.0"
[package.extras]
extended-testing = ["jinja2 (>=3,<4)"]
[package.source] [package.source]
type = "directory" type = "directory"
url = "../../core" url = "../../core"
@ -608,7 +605,7 @@ url = "../openai"
[[package]] [[package]]
name = "langchain-standard-tests" name = "langchain-standard-tests"
version = "0.1.0" version = "0.1.1"
description = "Standard tests for LangChain implementations" description = "Standard tests for LangChain implementations"
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@ -625,13 +622,13 @@ url = "../../standard-tests"
[[package]] [[package]]
name = "langsmith" name = "langsmith"
version = "0.1.69" version = "0.1.77"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false optional = false
python-versions = "<4.0,>=3.8.1" python-versions = "<4.0,>=3.8.1"
files = [ files = [
{file = "langsmith-0.1.69-py3-none-any.whl", hash = "sha256:3d7bd6fadb0852fc4cd2e7cf8a1593306046900052da3970bb2b48ed21cc73d8"}, {file = "langsmith-0.1.77-py3-none-any.whl", hash = "sha256:2202cc21b1ed7e7b9e5d2af2694be28898afa048c09fdf09f620cbd9301755ae"},
{file = "langsmith-0.1.69.tar.gz", hash = "sha256:0146764904a8e479620b7e73efcba1cf172b621799564dd7e7342859c05c264a"}, {file = "langsmith-0.1.77.tar.gz", hash = "sha256:4ace09077a9a4e412afeb4b517ca68e7de7b07f36e4792dc8236ac5207c0c0c7"},
] ]
[package.dependencies] [package.dependencies]
@ -871,6 +868,51 @@ files = [
{file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"},
] ]
[[package]]
name = "numpy"
version = "1.26.4"
description = "Fundamental package for array computing in Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
{file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
{file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
{file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
{file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
{file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
{file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
{file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
{file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
{file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
{file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
{file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
{file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
{file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
{file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
]
[[package]] [[package]]
name = "openai" name = "openai"
version = "1.30.1" version = "1.30.1"
@ -1679,4 +1721,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "e7b21f556475be4c7133b74b6b0e138012bef9d47bc5bdc9709b24e55d9500f0" content-hash = "8a868382f8f3b693dccc1ce99428cdf9d6f8b6f77b3403c342c2bcc7b8526db9"

View File

@ -43,6 +43,12 @@ codespell = "^2.2.0"
optional = true optional = true
[tool.poetry.group.test_integration.dependencies] [tool.poetry.group.test_integration.dependencies]
# Support Python 3.8 and 3.12+.
numpy = [
{version = "^1", python = "<3.12"},
{version = "^1.26.0", python = ">=3.12"}
]
[tool.poetry.group.lint] [tool.poetry.group.lint]
optional = true optional = true

View File

@ -2,20 +2,17 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain_together import ChatTogether from langchain_together import ChatTogether
class TestTogethertandard(ChatModelIntegrationTests): class TestTogetherStandard(ChatModelIntegrationTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatTogether return ChatTogether
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {"model": "mistralai/Mistral-7B-Instruct-v0.1"}
"model": "mistralai/Mistral-7B-Instruct-v0.1",
}

View File

@ -2,7 +2,6 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests from langchain_standard_tests.unit_tests import ChatModelUnitTests
@ -10,12 +9,10 @@ from langchain_together import ChatTogether
class TestTogetherStandard(ChatModelUnitTests): class TestTogetherStandard(ChatModelUnitTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatTogether return ChatTogether
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {"model": "meta-llama/Llama-3-8b-chat-hf"}
"model": "meta-llama/Llama-3-8b-chat-hf",
}

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]] [[package]]
name = "anyio" name = "anyio"
@ -340,7 +340,7 @@ files = [
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "0.2.4" version = "0.2.7"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@ -349,15 +349,12 @@ develop = true
[package.dependencies] [package.dependencies]
jsonpatch = "^1.33" jsonpatch = "^1.33"
langsmith = "^0.1.66" langsmith = "^0.1.75"
packaging = "^23.2" packaging = ">=23.2,<25"
pydantic = ">=1,<3" pydantic = ">=1,<3"
PyYAML = ">=5.3" PyYAML = ">=5.3"
tenacity = "^8.1.0" tenacity = "^8.1.0"
[package.extras]
extended-testing = ["jinja2 (>=3,<4)"]
[package.source] [package.source]
type = "directory" type = "directory"
url = "../../core" url = "../../core"
@ -382,7 +379,7 @@ url = "../openai"
[[package]] [[package]]
name = "langchain-standard-tests" name = "langchain-standard-tests"
version = "0.1.0" version = "0.1.1"
description = "Standard tests for LangChain implementations" description = "Standard tests for LangChain implementations"
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@ -399,13 +396,13 @@ url = "../../standard-tests"
[[package]] [[package]]
name = "langsmith" name = "langsmith"
version = "0.1.69" version = "0.1.77"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false optional = false
python-versions = "<4.0,>=3.8.1" python-versions = "<4.0,>=3.8.1"
files = [ files = [
{file = "langsmith-0.1.69-py3-none-any.whl", hash = "sha256:3d7bd6fadb0852fc4cd2e7cf8a1593306046900052da3970bb2b48ed21cc73d8"}, {file = "langsmith-0.1.77-py3-none-any.whl", hash = "sha256:2202cc21b1ed7e7b9e5d2af2694be28898afa048c09fdf09f620cbd9301755ae"},
{file = "langsmith-0.1.69.tar.gz", hash = "sha256:0146764904a8e479620b7e73efcba1cf172b621799564dd7e7342859c05c264a"}, {file = "langsmith-0.1.77.tar.gz", hash = "sha256:4ace09077a9a4e412afeb4b517ca68e7de7b07f36e4792dc8236ac5207c0c0c7"},
] ]
[package.dependencies] [package.dependencies]
@ -546,6 +543,51 @@ files = [
{file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"},
] ]
[[package]]
name = "numpy"
version = "1.26.4"
description = "Fundamental package for array computing in Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
{file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
{file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
{file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
{file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
{file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
{file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
{file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
{file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
{file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
{file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
{file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
{file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
{file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
{file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
]
[[package]] [[package]]
name = "openai" name = "openai"
version = "1.30.1" version = "1.30.1"
@ -725,26 +767,31 @@ python-versions = ">=3.8"
files = [ files = [
{file = "PyMuPDF-1.24.3-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:8d63d850d337c10fa49859697b9517e461b28e6d5d5a80121c72cc518eb0bae0"}, {file = "PyMuPDF-1.24.3-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:8d63d850d337c10fa49859697b9517e461b28e6d5d5a80121c72cc518eb0bae0"},
{file = "PyMuPDF-1.24.3-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:5f4a9ffabbcf8f19f6938484702e393ed6d423516f3e52c9d443162e3e42a884"}, {file = "PyMuPDF-1.24.3-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:5f4a9ffabbcf8f19f6938484702e393ed6d423516f3e52c9d443162e3e42a884"},
{file = "PyMuPDF-1.24.3-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:c7dfddf19d2a8c734c5439692e87419c86f2621f1f205100355afb3bb43e5675"},
{file = "PyMuPDF-1.24.3-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:cf743d8c7f7261112153525ba7de1d954f9d563b875414814b27da35fb0df2cc"}, {file = "PyMuPDF-1.24.3-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:cf743d8c7f7261112153525ba7de1d954f9d563b875414814b27da35fb0df2cc"},
{file = "PyMuPDF-1.24.3-cp310-none-win32.whl", hash = "sha256:e30e8dec04c241739e0e9cf89b8a0317e991889dbca781e30abef228009c8cbd"}, {file = "PyMuPDF-1.24.3-cp310-none-win32.whl", hash = "sha256:e30e8dec04c241739e0e9cf89b8a0317e991889dbca781e30abef228009c8cbd"},
{file = "PyMuPDF-1.24.3-cp310-none-win_amd64.whl", hash = "sha256:3ceca02b143efe6b6f159d64a2f0e0aa32d0670894149a7f7144125fe2982da2"}, {file = "PyMuPDF-1.24.3-cp310-none-win_amd64.whl", hash = "sha256:3ceca02b143efe6b6f159d64a2f0e0aa32d0670894149a7f7144125fe2982da2"},
{file = "PyMuPDF-1.24.3-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:171313ee03e031437687cf856914eb61f66a5a76eddedc63048a63b69b00474b"}, {file = "PyMuPDF-1.24.3-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:171313ee03e031437687cf856914eb61f66a5a76eddedc63048a63b69b00474b"},
{file = "PyMuPDF-1.24.3-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a421b332c257e70d9daed350cebefc043817ae0fd6b361734ee27f98288cc8c7"}, {file = "PyMuPDF-1.24.3-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a421b332c257e70d9daed350cebefc043817ae0fd6b361734ee27f98288cc8c7"},
{file = "PyMuPDF-1.24.3-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:cc519230e352024111f065a1d32366eea4f1f1034e01515f10dbed709d9ab5ad"},
{file = "PyMuPDF-1.24.3-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:9df2d5e89810d3708fb8933dbc07505a57bfcb976a72bc559c7f0ccacc054c76"}, {file = "PyMuPDF-1.24.3-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:9df2d5e89810d3708fb8933dbc07505a57bfcb976a72bc559c7f0ccacc054c76"},
{file = "PyMuPDF-1.24.3-cp311-none-win32.whl", hash = "sha256:1de61f186c8367d1647d679bf6a4a77198751b378f9b67958a3b5d59adbc8c95"}, {file = "PyMuPDF-1.24.3-cp311-none-win32.whl", hash = "sha256:1de61f186c8367d1647d679bf6a4a77198751b378f9b67958a3b5d59adbc8c95"},
{file = "PyMuPDF-1.24.3-cp311-none-win_amd64.whl", hash = "sha256:28e8c6c29de2951e29f98f17752eff0e80776fca7fe7ed5c7368363dff887c6c"}, {file = "PyMuPDF-1.24.3-cp311-none-win_amd64.whl", hash = "sha256:28e8c6c29de2951e29f98f17752eff0e80776fca7fe7ed5c7368363dff887c6c"},
{file = "PyMuPDF-1.24.3-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:34ab87e6d0f79eea9b632ed0401de20aff2622c95aa1a57fd17b49401c22c906"}, {file = "PyMuPDF-1.24.3-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:34ab87e6d0f79eea9b632ed0401de20aff2622c95aa1a57fd17b49401c22c906"},
{file = "PyMuPDF-1.24.3-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:ef2311861a3173072c489dc365827bb26f2c4487f969501afbbf1746478553ea"}, {file = "PyMuPDF-1.24.3-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:ef2311861a3173072c489dc365827bb26f2c4487f969501afbbf1746478553ea"},
{file = "PyMuPDF-1.24.3-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:c4df2c50eba8fb8d8ffe63bd4099c57b562d11ed01dcf6cd922c4ea774212a34"},
{file = "PyMuPDF-1.24.3-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:401f2da8621f19bc302efa2a404c794b17982dea0e552b48ecd2c3f8d10b4707"}, {file = "PyMuPDF-1.24.3-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:401f2da8621f19bc302efa2a404c794b17982dea0e552b48ecd2c3f8d10b4707"},
{file = "PyMuPDF-1.24.3-cp312-none-win32.whl", hash = "sha256:ce4c07355b45e95803d1221cece01be58e32d1d9daec0d1ebc075ad03640c177"}, {file = "PyMuPDF-1.24.3-cp312-none-win32.whl", hash = "sha256:ce4c07355b45e95803d1221cece01be58e32d1d9daec0d1ebc075ad03640c177"},
{file = "PyMuPDF-1.24.3-cp312-none-win_amd64.whl", hash = "sha256:4f084f735e2e2d21f2c76de1abdcb44261889ec01a2842b57e69c89502f74b7a"}, {file = "PyMuPDF-1.24.3-cp312-none-win_amd64.whl", hash = "sha256:4f084f735e2e2d21f2c76de1abdcb44261889ec01a2842b57e69c89502f74b7a"},
{file = "PyMuPDF-1.24.3-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:14b2459e1a7e4dbf9ec6026e6056ccba6868bdfff1ffb346fd910108a61be095"}, {file = "PyMuPDF-1.24.3-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:14b2459e1a7e4dbf9ec6026e6056ccba6868bdfff1ffb346fd910108a61be095"},
{file = "PyMuPDF-1.24.3-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:f3572c2a85a12026637d485d6156b7f279a4aac7f474a341e5e06e8943ab2e0b"}, {file = "PyMuPDF-1.24.3-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:f3572c2a85a12026637d485d6156b7f279a4aac7f474a341e5e06e8943ab2e0b"},
{file = "PyMuPDF-1.24.3-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:b0ceed71fa62eebd1bf8b55875cd6da7c2f09bbe2067218b68b5deb0d9feaa6e"},
{file = "PyMuPDF-1.24.3-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:e6b9100fa5194be1240b9998643ba122fcffd76149dccda3607455ccfed5fa2b"}, {file = "PyMuPDF-1.24.3-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:e6b9100fa5194be1240b9998643ba122fcffd76149dccda3607455ccfed5fa2b"},
{file = "PyMuPDF-1.24.3-cp38-none-win32.whl", hash = "sha256:88e52a5c6d0375d27401c08fe7f7894f19db4af31169ba6deb6b3c1453f8b6e0"}, {file = "PyMuPDF-1.24.3-cp38-none-win32.whl", hash = "sha256:88e52a5c6d0375d27401c08fe7f7894f19db4af31169ba6deb6b3c1453f8b6e0"},
{file = "PyMuPDF-1.24.3-cp38-none-win_amd64.whl", hash = "sha256:45c93944a14b19da3ee9b6d648e609f3ca35b8bca5c1cd16e6addcc59e7816d9"}, {file = "PyMuPDF-1.24.3-cp38-none-win_amd64.whl", hash = "sha256:45c93944a14b19da3ee9b6d648e609f3ca35b8bca5c1cd16e6addcc59e7816d9"},
{file = "PyMuPDF-1.24.3-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:0d4b6caf5ad25b7bd654ad4d42b8b3a00683b742bc5a81b8aeface79811386d5"}, {file = "PyMuPDF-1.24.3-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:0d4b6caf5ad25b7bd654ad4d42b8b3a00683b742bc5a81b8aeface79811386d5"},
{file = "PyMuPDF-1.24.3-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:6b52ee0460db88c71a06677353a0c768a8bb17718aa313462e9847ed1bf53f87"}, {file = "PyMuPDF-1.24.3-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:6b52ee0460db88c71a06677353a0c768a8bb17718aa313462e9847ed1bf53f87"},
{file = "PyMuPDF-1.24.3-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:06a8a3226c9ec97c5e1df8cd16ec29b5df83d04ae88e9e0f5f4e25fcc1b997a1"},
{file = "PyMuPDF-1.24.3-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:ce113303b41adb74ae30ebd98761d9bd53477573e47566f05b3b7ff1c7354675"}, {file = "PyMuPDF-1.24.3-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:ce113303b41adb74ae30ebd98761d9bd53477573e47566f05b3b7ff1c7354675"},
{file = "PyMuPDF-1.24.3-cp39-none-win32.whl", hash = "sha256:e4b4b2d5700c48a67da278476767488005408fac29426467b5bb437012197c0b"}, {file = "PyMuPDF-1.24.3-cp39-none-win32.whl", hash = "sha256:e4b4b2d5700c48a67da278476767488005408fac29426467b5bb437012197c0b"},
{file = "PyMuPDF-1.24.3-cp39-none-win_amd64.whl", hash = "sha256:39acbac2854ef5b58f28c71bb19e84840771a771ec09cb33c4e66e2679c3b419"}, {file = "PyMuPDF-1.24.3-cp39-none-win_amd64.whl", hash = "sha256:39acbac2854ef5b58f28c71bb19e84840771a771ec09cb33c4e66e2679c3b419"},
@ -763,6 +810,7 @@ python-versions = ">=3.8"
files = [ files = [
{file = "PyMuPDFb-1.24.3-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d2ccca660042896d4af479f979ec10674c5a0b3cd2d9ecb0011f08dc82380cce"}, {file = "PyMuPDFb-1.24.3-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d2ccca660042896d4af479f979ec10674c5a0b3cd2d9ecb0011f08dc82380cce"},
{file = "PyMuPDFb-1.24.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:ad51d21086a16199684a3eebcb47d9c8460fc27e7bebae77f5fe64e8c34ebf34"}, {file = "PyMuPDFb-1.24.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:ad51d21086a16199684a3eebcb47d9c8460fc27e7bebae77f5fe64e8c34ebf34"},
{file = "PyMuPDFb-1.24.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3e7aab000d707c40e3254cd60152897b90952ed9a3567584d70974292f4912ce"},
{file = "PyMuPDFb-1.24.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f39588fd2b7a63e2456df42cd8925c316202e0eb77d115d9c01ba032b2c9086f"}, {file = "PyMuPDFb-1.24.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f39588fd2b7a63e2456df42cd8925c316202e0eb77d115d9c01ba032b2c9086f"},
{file = "PyMuPDFb-1.24.3-py3-none-win32.whl", hash = "sha256:0d606a10cb828cefc9f864bf67bc9d46e8007af55e643f022b59d378af4151a8"}, {file = "PyMuPDFb-1.24.3-py3-none-win32.whl", hash = "sha256:0d606a10cb828cefc9f864bf67bc9d46e8007af55e643f022b59d378af4151a8"},
{file = "PyMuPDFb-1.24.3-py3-none-win_amd64.whl", hash = "sha256:e88289bd4b4afe5966a028774b302f37d4b51dad5c5e6720dd04524910db6c6e"}, {file = "PyMuPDFb-1.24.3-py3-none-win_amd64.whl", hash = "sha256:e88289bd4b4afe5966a028774b302f37d4b51dad5c5e6720dd04524910db6c6e"},
@ -1304,4 +1352,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "0073172ce2312055480e9ff47dc99ce7dfd6809208ad5ea4cee5ecf7f12eef56" content-hash = "b21648a1fdc08f901c82fb3b4773682f0a4b83b03b97ae1ddbd0834b730ff8c2"

View File

@ -43,6 +43,12 @@ codespell = "^2.2.0"
optional = true optional = true
[tool.poetry.group.test_integration.dependencies] [tool.poetry.group.test_integration.dependencies]
# Support Python 3.8 and 3.12+.
numpy = [
{version = "^1", python = "<3.12"},
{version = "^1.26.0", python = ">=3.12"}
]
[tool.poetry.group.lint] [tool.poetry.group.lint]
optional = true optional = true

View File

@ -2,7 +2,6 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
@ -10,12 +9,10 @@ from langchain_upstage import ChatUpstage
class TestUpstageStandard(ChatModelIntegrationTests): class TestUpstageStandard(ChatModelIntegrationTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatUpstage return ChatUpstage
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {"model": "solar-1-mini-chat"}
"model": "solar-1-mini-chat",
}

View File

@ -2,7 +2,6 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests from langchain_standard_tests.unit_tests import ChatModelUnitTests
@ -10,12 +9,10 @@ from langchain_upstage import ChatUpstage
class TestUpstageStandard(ChatModelUnitTests): class TestUpstageStandard(ChatModelUnitTests):
@pytest.fixture @property
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatUpstage return ChatUpstage
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return { return {"model": "solar-1-mini-chat"}
"model": "solar-1-mini-chat",
}

View File

@ -1,74 +1,31 @@
import json import json
from abc import ABC, abstractmethod
from typing import Type
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool from langchain_standard_tests.unit_tests.chat_models import (
ChatModelTests,
my_adder_tool,
)
class Person(BaseModel): class ChatModelIntegrationTests(ChatModelTests):
name: str = Field(..., description="The name of the person.") def test_invoke(self, model: BaseChatModel) -> None:
age: int = Field(..., description="The age of the person.")
@tool
def my_adder_tool(a: int, b: int) -> int:
"""Takes two integers, a and b, and returns their sum."""
return a + b
class ChatModelIntegrationTests(ABC):
@abstractmethod
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
...
@pytest.fixture
def chat_model_params(self) -> dict:
return {}
@pytest.fixture
def chat_model_has_tool_calling(
self, chat_model_class: Type[BaseChatModel]
) -> bool:
return chat_model_class.bind_tools is not BaseChatModel.bind_tools
@pytest.fixture
def chat_model_has_structured_output(
self, chat_model_class: Type[BaseChatModel]
) -> bool:
return (
chat_model_class.with_structured_output
is not BaseChatModel.with_structured_output
)
def test_invoke(
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
model = chat_model_class(**chat_model_params)
result = model.invoke("Hello") result = model.invoke("Hello")
assert result is not None assert result is not None
assert isinstance(result, AIMessage) assert isinstance(result, AIMessage)
assert isinstance(result.content, str) assert isinstance(result.content, str)
assert len(result.content) > 0 assert len(result.content) > 0
async def test_ainvoke( async def test_ainvoke(self, model: BaseChatModel) -> None:
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
model = chat_model_class(**chat_model_params)
result = await model.ainvoke("Hello") result = await model.ainvoke("Hello")
assert result is not None assert result is not None
assert isinstance(result, AIMessage) assert isinstance(result, AIMessage)
assert isinstance(result.content, str) assert isinstance(result.content, str)
assert len(result.content) > 0 assert len(result.content) > 0
def test_stream( def test_stream(self, model: BaseChatModel) -> None:
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
model = chat_model_class(**chat_model_params)
num_tokens = 0 num_tokens = 0
for token in model.stream("Hello"): for token in model.stream("Hello"):
assert token is not None assert token is not None
@ -76,10 +33,7 @@ class ChatModelIntegrationTests(ABC):
num_tokens += len(token.content) num_tokens += len(token.content)
assert num_tokens > 0 assert num_tokens > 0
async def test_astream( async def test_astream(self, model: BaseChatModel) -> None:
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
model = chat_model_class(**chat_model_params)
num_tokens = 0 num_tokens = 0
async for token in model.astream("Hello"): async for token in model.astream("Hello"):
assert token is not None assert token is not None
@ -87,10 +41,7 @@ class ChatModelIntegrationTests(ABC):
num_tokens += len(token.content) num_tokens += len(token.content)
assert num_tokens > 0 assert num_tokens > 0
def test_batch( def test_batch(self, model: BaseChatModel) -> None:
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
model = chat_model_class(**chat_model_params)
batch_results = model.batch(["Hello", "Hey"]) batch_results = model.batch(["Hello", "Hey"])
assert batch_results is not None assert batch_results is not None
assert isinstance(batch_results, list) assert isinstance(batch_results, list)
@ -101,10 +52,7 @@ class ChatModelIntegrationTests(ABC):
assert isinstance(result.content, str) assert isinstance(result.content, str)
assert len(result.content) > 0 assert len(result.content) > 0
async def test_abatch( async def test_abatch(self, model: BaseChatModel) -> None:
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
model = chat_model_class(**chat_model_params)
batch_results = await model.abatch(["Hello", "Hey"]) batch_results = await model.abatch(["Hello", "Hey"])
assert batch_results is not None assert batch_results is not None
assert isinstance(batch_results, list) assert isinstance(batch_results, list)
@ -115,14 +63,11 @@ class ChatModelIntegrationTests(ABC):
assert isinstance(result.content, str) assert isinstance(result.content, str)
assert len(result.content) > 0 assert len(result.content) > 0
def test_conversation( def test_conversation(self, model: BaseChatModel) -> None:
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
model = chat_model_class(**chat_model_params)
messages = [ messages = [
HumanMessage(content="hello"), HumanMessage("hello"),
AIMessage(content="hello"), AIMessage("hello"),
HumanMessage(content="how are you"), HumanMessage("how are you"),
] ]
result = model.invoke(messages) result = model.invoke(messages)
assert result is not None assert result is not None
@ -130,10 +75,9 @@ class ChatModelIntegrationTests(ABC):
assert isinstance(result.content, str) assert isinstance(result.content, str)
assert len(result.content) > 0 assert len(result.content) > 0
def test_usage_metadata( def test_usage_metadata(self, model: BaseChatModel) -> None:
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict if not self.returns_usage_metadata:
) -> None: pytest.skip("Not implemented.")
model = chat_model_class(**chat_model_params)
result = model.invoke("Hello") result = model.invoke("Hello")
assert result is not None assert result is not None
assert isinstance(result, AIMessage) assert isinstance(result, AIMessage)
@ -142,39 +86,35 @@ class ChatModelIntegrationTests(ABC):
assert isinstance(result.usage_metadata["output_tokens"], int) assert isinstance(result.usage_metadata["output_tokens"], int)
assert isinstance(result.usage_metadata["total_tokens"], int) assert isinstance(result.usage_metadata["total_tokens"], int)
def test_stop_sequence( def test_stop_sequence(self, model: BaseChatModel) -> None:
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
model = chat_model_class(**chat_model_params)
result = model.invoke("hi", stop=["you"]) result = model.invoke("hi", stop=["you"])
assert isinstance(result, AIMessage) assert isinstance(result, AIMessage)
model = chat_model_class(**chat_model_params, stop=["you"]) custom_model = self.chat_model_class(
result = model.invoke("hi") **{**self.chat_model_params, "stop": ["you"]}
)
result = custom_model.invoke("hi")
assert isinstance(result, AIMessage) assert isinstance(result, AIMessage)
def test_tool_message_histories_string_content( def test_tool_message_histories_string_content(
self, self,
chat_model_class: Type[BaseChatModel], model: BaseChatModel,
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None: ) -> None:
""" """
Test that message histories are compatible with string tool contents Test that message histories are compatible with string tool contents
(e.g. OpenAI). (e.g. OpenAI).
""" """
if not chat_model_has_tool_calling: if not self.has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
model = chat_model_class(**chat_model_params)
model_with_tools = model.bind_tools([my_adder_tool]) model_with_tools = model.bind_tools([my_adder_tool])
function_name = "my_adder_tool" function_name = "my_adder_tool"
function_args = {"a": "1", "b": "2"} function_args = {"a": "1", "b": "2"}
messages_string_content = [ messages_string_content = [
HumanMessage(content="What is 1 + 2"), HumanMessage("What is 1 + 2"),
# string content (e.g. OpenAI) # string content (e.g. OpenAI)
AIMessage( AIMessage(
content="", "",
tool_calls=[ tool_calls=[
{ {
"name": function_name, "name": function_name,
@ -184,8 +124,8 @@ class ChatModelIntegrationTests(ABC):
], ],
), ),
ToolMessage( ToolMessage(
json.dumps({"result": 3}),
name=function_name, name=function_name,
content=json.dumps({"result": 3}),
tool_call_id="abc123", tool_call_id="abc123",
), ),
] ]
@ -194,26 +134,23 @@ class ChatModelIntegrationTests(ABC):
def test_tool_message_histories_list_content( def test_tool_message_histories_list_content(
self, self,
chat_model_class: Type[BaseChatModel], model: BaseChatModel,
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None: ) -> None:
""" """
Test that message histories are compatible with list tool contents Test that message histories are compatible with list tool contents
(e.g. Anthropic). (e.g. Anthropic).
""" """
if not chat_model_has_tool_calling: if not self.has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
model = chat_model_class(**chat_model_params)
model_with_tools = model.bind_tools([my_adder_tool]) model_with_tools = model.bind_tools([my_adder_tool])
function_name = "my_adder_tool" function_name = "my_adder_tool"
function_args = {"a": 1, "b": 2} function_args = {"a": 1, "b": 2}
messages_list_content = [ messages_list_content = [
HumanMessage(content="What is 1 + 2"), HumanMessage("What is 1 + 2"),
# List content (e.g., Anthropic) # List content (e.g., Anthropic)
AIMessage( AIMessage(
content=[ [
{"type": "text", "text": "some text"}, {"type": "text", "text": "some text"},
{ {
"type": "tool_use", "type": "tool_use",
@ -231,8 +168,8 @@ class ChatModelIntegrationTests(ABC):
], ],
), ),
ToolMessage( ToolMessage(
json.dumps({"result": 3}),
name=function_name, name=function_name,
content=json.dumps({"result": 3}),
tool_call_id="abc123", tool_call_id="abc123",
), ),
] ]
@ -241,25 +178,22 @@ class ChatModelIntegrationTests(ABC):
def test_structured_few_shot_examples( def test_structured_few_shot_examples(
self, self,
chat_model_class: Type[BaseChatModel], model: BaseChatModel,
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None: ) -> None:
""" """
Test that model can process few-shot examples with tool calls. Test that model can process few-shot examples with tool calls.
""" """
if not chat_model_has_tool_calling: if not self.has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
model = chat_model_class(**chat_model_params) model_with_tools = model.bind_tools([my_adder_tool], tool_choice="any")
model_with_tools = model.bind_tools([my_adder_tool])
function_name = "my_adder_tool" function_name = "my_adder_tool"
function_args = {"a": 1, "b": 2} function_args = {"a": 1, "b": 2}
function_result = json.dumps({"result": 3}) function_result = json.dumps({"result": 3})
messages_string_content = [ messages_string_content = [
HumanMessage(content="What is 1 + 2"), HumanMessage("What is 1 + 2"),
AIMessage( AIMessage(
content="", "",
tool_calls=[ tool_calls=[
{ {
"name": function_name, "name": function_name,
@ -269,12 +203,12 @@ class ChatModelIntegrationTests(ABC):
], ],
), ),
ToolMessage( ToolMessage(
function_result,
name=function_name, name=function_name,
content=function_result,
tool_call_id="abc123", tool_call_id="abc123",
), ),
AIMessage(content=function_result), AIMessage(function_result),
HumanMessage(content="What is 3 + 4"), HumanMessage("What is 3 + 4"),
] ]
result_string_content = model_with_tools.invoke(messages_string_content) result_string_content = model_with_tools.invoke(messages_string_content)
assert isinstance(result_string_content, AIMessage) assert isinstance(result_string_content, AIMessage)

View File

@ -1,13 +1,16 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Type from typing import Any, List, Literal, Optional, Type
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool from langchain_core.tools import tool
class Person(BaseModel): class Person(BaseModel):
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.") name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.") age: int = Field(..., description="The age of the person.")
@ -18,81 +21,105 @@ def my_adder_tool(a: int, b: int) -> int:
return a + b return a + b
class ChatModelUnitTests(ABC): class ChatModelTests(ABC):
@property
@abstractmethod @abstractmethod
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
... ...
@pytest.fixture @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return {} return {}
@pytest.fixture @property
def chat_model_has_tool_calling( def standard_chat_model_params(self) -> dict:
self, chat_model_class: Type[BaseChatModel] return {
) -> bool: "temperature": 0,
return chat_model_class.bind_tools is not BaseChatModel.bind_tools "max_tokens": 100,
"timeout": 60,
"stop_sequences": [],
"max_retries": 2,
}
@pytest.fixture @pytest.fixture
def chat_model_has_structured_output( def model(self) -> BaseChatModel:
self, chat_model_class: Type[BaseChatModel] return self.chat_model_class(
) -> bool: **{**self.standard_chat_model_params, **self.chat_model_params}
)
@property
def has_tool_calling(self) -> bool:
return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools
@property
def has_structured_output(self) -> bool:
return ( return (
chat_model_class.with_structured_output self.chat_model_class.with_structured_output
is not BaseChatModel.with_structured_output is not BaseChatModel.with_structured_output
) )
def test_chat_model_init( @property
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict def supports_image_inputs(self) -> bool:
) -> None: return False
model = chat_model_class(**chat_model_params)
@property
def supports_video_inputs(self) -> bool:
return False
@property
def returns_usage_metadata(self) -> bool:
return True
class ChatModelUnitTests(ChatModelTests):
@property
def standard_chat_model_params(self) -> dict:
params = super().standard_chat_model_params
params["api_key"] = "test"
return params
def test_init(self) -> None:
model = self.chat_model_class(
**{**self.standard_chat_model_params, **self.chat_model_params}
)
assert model is not None assert model is not None
def test_chat_model_init_api_key( def test_init_streaming(
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
params = {**chat_model_params, "api_key": "test"}
model = chat_model_class(**params) # type: ignore
assert model is not None
def test_chat_model_init_streaming(
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
model = chat_model_class(streaming=True, **chat_model_params) # type: ignore
assert model is not None
def test_chat_model_bind_tool_pydantic(
self, self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None: ) -> None:
if not chat_model_has_tool_calling: model = self.chat_model_class(
**{
**self.standard_chat_model_params,
**self.chat_model_params,
"streaming": True,
}
)
assert model is not None
def test_bind_tool_pydantic(
self,
model: BaseChatModel,
) -> None:
if not self.has_tool_calling:
return return
model = chat_model_class(**chat_model_params) tool_model = model.bind_tools(
[Person, Person.schema(), my_adder_tool], tool_choice="any"
)
assert isinstance(tool_model, RunnableBinding)
assert hasattr(model, "bind_tools") @pytest.mark.parametrize("schema", [Person, Person.schema()])
tool_model = model.bind_tools([Person]) def test_with_structured_output(
assert tool_model is not None
def test_chat_model_with_structured_output(
self, self,
chat_model_class: Type[BaseChatModel], model: BaseChatModel,
chat_model_params: dict, schema: Any,
chat_model_has_structured_output: bool,
) -> None: ) -> None:
if not chat_model_has_structured_output: if not self.has_structured_output:
return return
model = chat_model_class(**chat_model_params) assert model.with_structured_output(schema) is not None
assert model is not None
assert model.with_structured_output(Person) is not None
def test_standard_params( def test_standard_params(self, model: BaseChatModel) -> None:
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
class ExpectedParams(BaseModel): class ExpectedParams(BaseModel):
ls_provider: str ls_provider: str
ls_model_name: str ls_model_name: str
@ -101,7 +128,6 @@ class ChatModelUnitTests(ABC):
ls_max_tokens: Optional[int] ls_max_tokens: Optional[int]
ls_stop: Optional[List[str]] ls_stop: Optional[List[str]]
model = chat_model_class(**chat_model_params)
ls_params = model._get_ls_params() ls_params = model._get_ls_params()
try: try:
ExpectedParams(**ls_params) ExpectedParams(**ls_params)
@ -109,7 +135,9 @@ class ChatModelUnitTests(ABC):
pytest.fail(f"Validation error: {e}") pytest.fail(f"Validation error: {e}")
# Test optional params # Test optional params
model = chat_model_class(max_tokens=10, stop=["test"], **chat_model_params) model = self.chat_model_class(
max_tokens=10, stop=["test"], **self.chat_model_params
)
ls_params = model._get_ls_params() ls_params = model._get_ls_params()
try: try:
ExpectedParams(**ls_params) ExpectedParams(**ls_params)