mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 22:04:37 +00:00
standard-tests[patch]: Update chat model standard tests (#22378)
- Refactor standard test classes to make them easier to configure - Update openai to support stop_sequences init param - Update groq to support stop_sequences init param - Update fireworks to support max_retries init param - Update ChatModel.bind_tools to type tool_choice - Update groq to handle tool_choice="any". **this may be controversial** --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
14f0cdad58
commit
d96f67b06f
@ -10,98 +10,38 @@ from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
from langchain_ai21 import ChatAI21
|
||||
|
||||
|
||||
class TestAI21J2(ChatModelIntegrationTests):
|
||||
class BaseTestAI21(ChatModelIntegrationTests):
|
||||
def teardown(self) -> None:
|
||||
# avoid getting rate limited
|
||||
time.sleep(1)
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatAI21
|
||||
|
||||
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
|
||||
def test_stream(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
super().test_stream(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
def test_stream(self, model: BaseChatModel) -> None:
|
||||
super().test_stream(model)
|
||||
|
||||
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
|
||||
async def test_astream(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
await super().test_astream(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
async def test_astream(self, model: BaseChatModel) -> None:
|
||||
await super().test_astream(model)
|
||||
|
||||
@pytest.mark.xfail(reason="Not implemented.")
|
||||
def test_usage_metadata(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
super().test_usage_metadata(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
def test_usage_metadata(self, model: BaseChatModel) -> None:
|
||||
super().test_usage_metadata(model)
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
class TestAI21J2(BaseTestAI21):
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "j2-ultra",
|
||||
}
|
||||
|
||||
|
||||
class TestAI21Jamba(ChatModelIntegrationTests):
|
||||
def teardown(self) -> None:
|
||||
# avoid getting rate limited
|
||||
time.sleep(1)
|
||||
|
||||
@pytest.fixture
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatAI21
|
||||
|
||||
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
|
||||
def test_stream(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
super().test_stream(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
|
||||
async def test_astream(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
await super().test_astream(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(reason="Not implemented.")
|
||||
def test_usage_metadata(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
super().test_usage_metadata(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
class TestAI21Jamba(BaseTestAI21):
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "jamba-instruct-preview",
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@ -10,11 +9,11 @@ from langchain_ai21 import ChatAI21
|
||||
|
||||
|
||||
class TestAI21J2(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatAI21
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "j2-ultra",
|
||||
@ -23,11 +22,11 @@ class TestAI21J2(ChatModelUnitTests):
|
||||
|
||||
|
||||
class TestAI21Jamba(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatAI21
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "jamba-instruct",
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
@ -10,12 +9,10 @@ from langchain_anthropic import ChatAnthropic
|
||||
|
||||
|
||||
class TestAnthropicStandard(ChatModelIntegrationTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatAnthropic
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "claude-3-haiku-20240307",
|
||||
}
|
||||
return {"model": "claude-3-haiku-20240307"}
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@ -10,12 +9,10 @@ from langchain_anthropic import ChatAnthropic
|
||||
|
||||
|
||||
class TestAnthropicStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatAnthropic
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "claude-3-haiku-20240307",
|
||||
}
|
||||
return {"model": "claude-3-haiku-20240307"}
|
||||
|
@ -296,6 +296,8 @@ class ChatFireworks(BaseChatModel):
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.0
|
||||
"""What sampling temperature to use."""
|
||||
stop: Optional[Union[str, List[str]]] = Field(None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
|
||||
@ -314,8 +316,8 @@ class ChatFireworks(BaseChatModel):
|
||||
"""Number of chat completions to generate for each prompt."""
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
stop: Optional[List[str]] = Field(None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
max_retries: Optional[int] = None
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -360,6 +362,9 @@ class ChatFireworks(BaseChatModel):
|
||||
values["client"] = Fireworks(**client_params).chat.completions
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = AsyncFireworks(**client_params).chat.completions
|
||||
if values["max_retries"]:
|
||||
values["client"]._max_retries = values["max_retries"]
|
||||
values["async_client"]._max_retries = values["max_retries"]
|
||||
return values
|
||||
|
||||
@property
|
||||
|
@ -10,11 +10,11 @@ from langchain_fireworks import ChatFireworks
|
||||
|
||||
|
||||
class TestFireworksStandard(ChatModelIntegrationTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatFireworks
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "accounts/fireworks/models/firefunction-v1",
|
||||
@ -22,12 +22,5 @@ class TestFireworksStandard(ChatModelIntegrationTests):
|
||||
}
|
||||
|
||||
@pytest.mark.xfail(reason="Not yet implemented.")
|
||||
def test_tool_message_histories_list_content(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
chat_model_has_tool_calling: bool,
|
||||
) -> None:
|
||||
super().test_tool_message_histories_list_content(
|
||||
chat_model_class, chat_model_params, chat_model_has_tool_calling
|
||||
)
|
||||
def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_message_histories_list_content(model)
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@ -10,12 +9,10 @@ from langchain_fireworks import ChatFireworks
|
||||
|
||||
|
||||
class TestFireworksStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatFireworks
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"api_key": "test_api_key",
|
||||
}
|
||||
return {"api_key": "test_api_key"}
|
||||
|
@ -304,6 +304,8 @@ class ChatGroq(BaseChatModel):
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.7
|
||||
"""What sampling temperature to use."""
|
||||
stop: Optional[Union[List[str], str]] = Field(None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
groq_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
@ -326,8 +328,6 @@ class ChatGroq(BaseChatModel):
|
||||
"""Number of chat completions to generate for each prompt."""
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
stop: Optional[List[str]] = Field(None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
default_headers: Union[Mapping[str, str], None] = None
|
||||
default_query: Union[Mapping[str, object], None] = None
|
||||
# Configure a custom httpx client. See the
|
||||
@ -449,7 +449,7 @@ class ChatGroq(BaseChatModel):
|
||||
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
|
||||
ls_params["ls_max_tokens"] = ls_max_tokens
|
||||
if ls_stop := stop or params.get("stop", None) or self.stop:
|
||||
ls_params["ls_stop"] = ls_stop
|
||||
ls_params["ls_stop"] = ls_stop if isinstance(ls_stop, list) else [ls_stop]
|
||||
return ls_params
|
||||
|
||||
def _generate(
|
||||
@ -804,10 +804,19 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
if tool_choice is not None and tool_choice:
|
||||
if tool_choice == "any":
|
||||
if len(tools) > 1:
|
||||
raise ValueError(
|
||||
f"Groq does not currently support {tool_choice=}. Should "
|
||||
f"be one of 'auto', 'none', or the name of the tool to call."
|
||||
)
|
||||
else:
|
||||
tool_choice = convert_to_openai_tool(tools[0])["function"]["name"]
|
||||
if isinstance(tool_choice, str) and (
|
||||
tool_choice not in ("auto", "any", "none")
|
||||
):
|
||||
tool_choice = {"type": "function", "function": {"name": tool_choice}}
|
||||
# TODO: Remove this update once 'any' is supported.
|
||||
if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
|
||||
raise ValueError(
|
||||
"When specifying `tool_choice`, you must provide exactly one "
|
||||
|
@ -92,5 +92,6 @@ filterwarnings = [
|
||||
'ignore:The method `ChatGroq.with_structured_output` is in beta',
|
||||
# Maintain support for pydantic 1.X
|
||||
'default:The `dict` method is deprecated; use `model_dump` instead:DeprecationWarning',
|
||||
"ignore:tool_choice='any' is not currently supported. Converting to 'auto'.",
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
|
@ -10,17 +10,10 @@ from langchain_groq import ChatGroq
|
||||
|
||||
|
||||
class TestGroqStandard(ChatModelIntegrationTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatGroq
|
||||
|
||||
@pytest.mark.xfail(reason="Not yet implemented.")
|
||||
def test_tool_message_histories_list_content(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
chat_model_has_tool_calling: bool,
|
||||
) -> None:
|
||||
super().test_tool_message_histories_list_content(
|
||||
chat_model_class, chat_model_params, chat_model_has_tool_calling
|
||||
)
|
||||
def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_message_histories_list_content(model)
|
||||
|
@ -2,14 +2,26 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
from langchain_core.runnables import RunnableBinding
|
||||
from langchain_standard_tests.unit_tests.chat_models import (
|
||||
ChatModelUnitTests,
|
||||
Person,
|
||||
my_adder_tool,
|
||||
)
|
||||
|
||||
from langchain_groq import ChatGroq
|
||||
|
||||
|
||||
class TestGroqStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatGroq
|
||||
|
||||
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
|
||||
"""Does not currently support tool_choice='any'."""
|
||||
if not self.has_tool_calling:
|
||||
return
|
||||
|
||||
tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool])
|
||||
assert isinstance(tool_model, RunnableBinding)
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
@ -10,13 +9,10 @@ from langchain_mistralai import ChatMistralAI
|
||||
|
||||
|
||||
class TestMistralStandard(ChatModelIntegrationTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatMistralAI
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "mistral-large-latest",
|
||||
"temperature": 0,
|
||||
}
|
||||
return {"model": "mistral-large-latest", "temperature": 0}
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@ -10,6 +9,6 @@ from langchain_mistralai import ChatMistralAI
|
||||
|
||||
|
||||
class TestMistralStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatMistralAI
|
||||
|
@ -3,13 +3,18 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Type, Union
|
||||
|
||||
import openai
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import LangSmithParams
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||
|
||||
@ -210,6 +215,27 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
).chat.completions
|
||||
return values
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
# As of 05/2024 Azure OpenAI doesn't support tool_choice="required".
|
||||
# TODO: Update this condition once tool_choice="required" is supported.
|
||||
if tool_choice in ("any", "required", True):
|
||||
if len(tools) > 1:
|
||||
raise ValueError(
|
||||
f"Azure OpenAI does not currently support {tool_choice=}. Should "
|
||||
f"be one of 'auto', 'none', or the name of the tool to call."
|
||||
)
|
||||
else:
|
||||
tool_choice = convert_to_openai_tool(tools[0])["function"]["name"]
|
||||
return super().bind_tools(tools, tool_choice=tool_choice, **kwargs)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
|
@ -345,6 +345,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
http_async_client: Union[Any, None] = None
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
http_client as well if you'd like a custom client for sync invocations."""
|
||||
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -441,6 +443,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"stream": self.streaming,
|
||||
"n": self.n,
|
||||
"temperature": self.temperature,
|
||||
"stop": self.stop,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
if self.max_tokens is not None:
|
||||
@ -548,8 +551,6 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params = self._default_params
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
@ -871,15 +872,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if tool_choice == "any":
|
||||
tool_choice = "required"
|
||||
elif isinstance(tool_choice, bool):
|
||||
if len(tools) > 1:
|
||||
raise ValueError(
|
||||
"tool_choice=True can only be used when a single tool is "
|
||||
f"passed in, received {len(tools)} tools."
|
||||
)
|
||||
tool_choice = {
|
||||
"type": "function",
|
||||
"function": {"name": formatted_tools[0]["function"]["name"]},
|
||||
}
|
||||
tool_choice = "required"
|
||||
elif isinstance(tool_choice, dict):
|
||||
tool_names = [
|
||||
formatted_tool["function"]["name"]
|
||||
@ -1094,7 +1087,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"schema must be specified when method is 'function_calling'. "
|
||||
"Received None."
|
||||
)
|
||||
llm = self.bind_tools([schema], tool_choice=True)
|
||||
llm = self.bind_tools([schema], tool_choice="any")
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], first_tool_only=True
|
||||
|
@ -440,8 +440,6 @@ class BaseOpenAI(BaseLLM):
|
||||
) -> List[List[str]]:
|
||||
"""Get the sub prompts for llm call."""
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
if params["max_tokens"] == -1:
|
||||
if len(prompts) != 1:
|
||||
|
@ -3,7 +3,6 @@
|
||||
import os
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
@ -19,15 +18,15 @@ DEPLOYMENT_NAME = os.environ.get(
|
||||
|
||||
|
||||
class TestOpenAIStandard(ChatModelIntegrationTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return AzureChatOpenAI
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"deployment_name": DEPLOYMENT_NAME,
|
||||
"openai_api_version": OPENAI_API_VERSION,
|
||||
"azure_endpoint": OPENAI_API_BASE,
|
||||
"openai_api_key": OPENAI_API_KEY,
|
||||
"api_key": OPENAI_API_KEY,
|
||||
}
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
@ -10,6 +9,6 @@ from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class TestOpenAIStandard(ChatModelIntegrationTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatOpenAI
|
||||
|
@ -99,13 +99,6 @@ def test_openai_stop_valid() -> None:
|
||||
assert first_output == second_output
|
||||
|
||||
|
||||
def test_openai_stop_error() -> None:
|
||||
"""Test openai stop logic on bad configuration."""
|
||||
llm = OpenAI(stop="3", temperature=0)
|
||||
with pytest.raises(ValueError):
|
||||
llm.invoke("write an ordered list of five items", stop=["\n"])
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_streaming() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
|
@ -2,23 +2,31 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.runnables import RunnableBinding
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
from langchain_standard_tests.unit_tests.chat_models import Person, my_adder_tool
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
|
||||
class TestOpenAIStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return AzureChatOpenAI
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"deployment_name": "test",
|
||||
"openai_api_version": "2021-10-01",
|
||||
"azure_endpoint": "https://test.azure.com",
|
||||
"openai_api_key": "test",
|
||||
}
|
||||
|
||||
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
|
||||
"""Does not currently support tool_choice='any'."""
|
||||
if not self.has_tool_calling:
|
||||
return
|
||||
|
||||
tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool])
|
||||
assert isinstance(tool_model, RunnableBinding)
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@ -10,6 +9,6 @@ from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class TestOpenAIStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatOpenAI
|
||||
|
64
libs/partners/together/poetry.lock
generated
64
libs/partners/together/poetry.lock
generated
@ -566,7 +566,7 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.2.4"
|
||||
version = "0.2.7"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -575,15 +575,12 @@ develop = true
|
||||
|
||||
[package.dependencies]
|
||||
jsonpatch = "^1.33"
|
||||
langsmith = "^0.1.66"
|
||||
packaging = "^23.2"
|
||||
langsmith = "^0.1.75"
|
||||
packaging = ">=23.2,<25"
|
||||
pydantic = ">=1,<3"
|
||||
PyYAML = ">=5.3"
|
||||
tenacity = "^8.1.0"
|
||||
|
||||
[package.extras]
|
||||
extended-testing = ["jinja2 (>=3,<4)"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
url = "../../core"
|
||||
@ -608,7 +605,7 @@ url = "../openai"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-standard-tests"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
description = "Standard tests for LangChain implementations"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -625,13 +622,13 @@ url = "../../standard-tests"
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.1.69"
|
||||
version = "0.1.77"
|
||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8.1"
|
||||
files = [
|
||||
{file = "langsmith-0.1.69-py3-none-any.whl", hash = "sha256:3d7bd6fadb0852fc4cd2e7cf8a1593306046900052da3970bb2b48ed21cc73d8"},
|
||||
{file = "langsmith-0.1.69.tar.gz", hash = "sha256:0146764904a8e479620b7e73efcba1cf172b621799564dd7e7342859c05c264a"},
|
||||
{file = "langsmith-0.1.77-py3-none-any.whl", hash = "sha256:2202cc21b1ed7e7b9e5d2af2694be28898afa048c09fdf09f620cbd9301755ae"},
|
||||
{file = "langsmith-0.1.77.tar.gz", hash = "sha256:4ace09077a9a4e412afeb4b517ca68e7de7b07f36e4792dc8236ac5207c0c0c7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -871,6 +868,51 @@ files = [
|
||||
{file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.26.4"
|
||||
description = "Fundamental package for array computing in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
|
||||
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.30.1"
|
||||
@ -1679,4 +1721,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "e7b21f556475be4c7133b74b6b0e138012bef9d47bc5bdc9709b24e55d9500f0"
|
||||
content-hash = "8a868382f8f3b693dccc1ce99428cdf9d6f8b6f77b3403c342c2bcc7b8526db9"
|
||||
|
@ -43,6 +43,12 @@ codespell = "^2.2.0"
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
# Support Python 3.8 and 3.12+.
|
||||
numpy = [
|
||||
{version = "^1", python = "<3.12"},
|
||||
{version = "^1.26.0", python = ">=3.12"}
|
||||
]
|
||||
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
@ -2,20 +2,17 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
from langchain_together import ChatTogether
|
||||
|
||||
|
||||
class TestTogethertandard(ChatModelIntegrationTests):
|
||||
@pytest.fixture
|
||||
class TestTogetherStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatTogether
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "mistralai/Mistral-7B-Instruct-v0.1",
|
||||
}
|
||||
return {"model": "mistralai/Mistral-7B-Instruct-v0.1"}
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@ -10,12 +9,10 @@ from langchain_together import ChatTogether
|
||||
|
||||
|
||||
class TestTogetherStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatTogether
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "meta-llama/Llama-3-8b-chat-hf",
|
||||
}
|
||||
return {"model": "meta-llama/Llama-3-8b-chat-hf"}
|
||||
|
72
libs/partners/upstage/poetry.lock
generated
72
libs/partners/upstage/poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
@ -340,7 +340,7 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.2.4"
|
||||
version = "0.2.7"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -349,15 +349,12 @@ develop = true
|
||||
|
||||
[package.dependencies]
|
||||
jsonpatch = "^1.33"
|
||||
langsmith = "^0.1.66"
|
||||
packaging = "^23.2"
|
||||
langsmith = "^0.1.75"
|
||||
packaging = ">=23.2,<25"
|
||||
pydantic = ">=1,<3"
|
||||
PyYAML = ">=5.3"
|
||||
tenacity = "^8.1.0"
|
||||
|
||||
[package.extras]
|
||||
extended-testing = ["jinja2 (>=3,<4)"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
url = "../../core"
|
||||
@ -382,7 +379,7 @@ url = "../openai"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-standard-tests"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
description = "Standard tests for LangChain implementations"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -399,13 +396,13 @@ url = "../../standard-tests"
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.1.69"
|
||||
version = "0.1.77"
|
||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8.1"
|
||||
files = [
|
||||
{file = "langsmith-0.1.69-py3-none-any.whl", hash = "sha256:3d7bd6fadb0852fc4cd2e7cf8a1593306046900052da3970bb2b48ed21cc73d8"},
|
||||
{file = "langsmith-0.1.69.tar.gz", hash = "sha256:0146764904a8e479620b7e73efcba1cf172b621799564dd7e7342859c05c264a"},
|
||||
{file = "langsmith-0.1.77-py3-none-any.whl", hash = "sha256:2202cc21b1ed7e7b9e5d2af2694be28898afa048c09fdf09f620cbd9301755ae"},
|
||||
{file = "langsmith-0.1.77.tar.gz", hash = "sha256:4ace09077a9a4e412afeb4b517ca68e7de7b07f36e4792dc8236ac5207c0c0c7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -546,6 +543,51 @@ files = [
|
||||
{file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.26.4"
|
||||
description = "Fundamental package for array computing in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
|
||||
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.30.1"
|
||||
@ -725,26 +767,31 @@ python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "PyMuPDF-1.24.3-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:8d63d850d337c10fa49859697b9517e461b28e6d5d5a80121c72cc518eb0bae0"},
|
||||
{file = "PyMuPDF-1.24.3-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:5f4a9ffabbcf8f19f6938484702e393ed6d423516f3e52c9d443162e3e42a884"},
|
||||
{file = "PyMuPDF-1.24.3-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:c7dfddf19d2a8c734c5439692e87419c86f2621f1f205100355afb3bb43e5675"},
|
||||
{file = "PyMuPDF-1.24.3-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:cf743d8c7f7261112153525ba7de1d954f9d563b875414814b27da35fb0df2cc"},
|
||||
{file = "PyMuPDF-1.24.3-cp310-none-win32.whl", hash = "sha256:e30e8dec04c241739e0e9cf89b8a0317e991889dbca781e30abef228009c8cbd"},
|
||||
{file = "PyMuPDF-1.24.3-cp310-none-win_amd64.whl", hash = "sha256:3ceca02b143efe6b6f159d64a2f0e0aa32d0670894149a7f7144125fe2982da2"},
|
||||
{file = "PyMuPDF-1.24.3-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:171313ee03e031437687cf856914eb61f66a5a76eddedc63048a63b69b00474b"},
|
||||
{file = "PyMuPDF-1.24.3-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a421b332c257e70d9daed350cebefc043817ae0fd6b361734ee27f98288cc8c7"},
|
||||
{file = "PyMuPDF-1.24.3-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:cc519230e352024111f065a1d32366eea4f1f1034e01515f10dbed709d9ab5ad"},
|
||||
{file = "PyMuPDF-1.24.3-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:9df2d5e89810d3708fb8933dbc07505a57bfcb976a72bc559c7f0ccacc054c76"},
|
||||
{file = "PyMuPDF-1.24.3-cp311-none-win32.whl", hash = "sha256:1de61f186c8367d1647d679bf6a4a77198751b378f9b67958a3b5d59adbc8c95"},
|
||||
{file = "PyMuPDF-1.24.3-cp311-none-win_amd64.whl", hash = "sha256:28e8c6c29de2951e29f98f17752eff0e80776fca7fe7ed5c7368363dff887c6c"},
|
||||
{file = "PyMuPDF-1.24.3-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:34ab87e6d0f79eea9b632ed0401de20aff2622c95aa1a57fd17b49401c22c906"},
|
||||
{file = "PyMuPDF-1.24.3-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:ef2311861a3173072c489dc365827bb26f2c4487f969501afbbf1746478553ea"},
|
||||
{file = "PyMuPDF-1.24.3-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:c4df2c50eba8fb8d8ffe63bd4099c57b562d11ed01dcf6cd922c4ea774212a34"},
|
||||
{file = "PyMuPDF-1.24.3-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:401f2da8621f19bc302efa2a404c794b17982dea0e552b48ecd2c3f8d10b4707"},
|
||||
{file = "PyMuPDF-1.24.3-cp312-none-win32.whl", hash = "sha256:ce4c07355b45e95803d1221cece01be58e32d1d9daec0d1ebc075ad03640c177"},
|
||||
{file = "PyMuPDF-1.24.3-cp312-none-win_amd64.whl", hash = "sha256:4f084f735e2e2d21f2c76de1abdcb44261889ec01a2842b57e69c89502f74b7a"},
|
||||
{file = "PyMuPDF-1.24.3-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:14b2459e1a7e4dbf9ec6026e6056ccba6868bdfff1ffb346fd910108a61be095"},
|
||||
{file = "PyMuPDF-1.24.3-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:f3572c2a85a12026637d485d6156b7f279a4aac7f474a341e5e06e8943ab2e0b"},
|
||||
{file = "PyMuPDF-1.24.3-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:b0ceed71fa62eebd1bf8b55875cd6da7c2f09bbe2067218b68b5deb0d9feaa6e"},
|
||||
{file = "PyMuPDF-1.24.3-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:e6b9100fa5194be1240b9998643ba122fcffd76149dccda3607455ccfed5fa2b"},
|
||||
{file = "PyMuPDF-1.24.3-cp38-none-win32.whl", hash = "sha256:88e52a5c6d0375d27401c08fe7f7894f19db4af31169ba6deb6b3c1453f8b6e0"},
|
||||
{file = "PyMuPDF-1.24.3-cp38-none-win_amd64.whl", hash = "sha256:45c93944a14b19da3ee9b6d648e609f3ca35b8bca5c1cd16e6addcc59e7816d9"},
|
||||
{file = "PyMuPDF-1.24.3-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:0d4b6caf5ad25b7bd654ad4d42b8b3a00683b742bc5a81b8aeface79811386d5"},
|
||||
{file = "PyMuPDF-1.24.3-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:6b52ee0460db88c71a06677353a0c768a8bb17718aa313462e9847ed1bf53f87"},
|
||||
{file = "PyMuPDF-1.24.3-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:06a8a3226c9ec97c5e1df8cd16ec29b5df83d04ae88e9e0f5f4e25fcc1b997a1"},
|
||||
{file = "PyMuPDF-1.24.3-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:ce113303b41adb74ae30ebd98761d9bd53477573e47566f05b3b7ff1c7354675"},
|
||||
{file = "PyMuPDF-1.24.3-cp39-none-win32.whl", hash = "sha256:e4b4b2d5700c48a67da278476767488005408fac29426467b5bb437012197c0b"},
|
||||
{file = "PyMuPDF-1.24.3-cp39-none-win_amd64.whl", hash = "sha256:39acbac2854ef5b58f28c71bb19e84840771a771ec09cb33c4e66e2679c3b419"},
|
||||
@ -763,6 +810,7 @@ python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "PyMuPDFb-1.24.3-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d2ccca660042896d4af479f979ec10674c5a0b3cd2d9ecb0011f08dc82380cce"},
|
||||
{file = "PyMuPDFb-1.24.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:ad51d21086a16199684a3eebcb47d9c8460fc27e7bebae77f5fe64e8c34ebf34"},
|
||||
{file = "PyMuPDFb-1.24.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3e7aab000d707c40e3254cd60152897b90952ed9a3567584d70974292f4912ce"},
|
||||
{file = "PyMuPDFb-1.24.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f39588fd2b7a63e2456df42cd8925c316202e0eb77d115d9c01ba032b2c9086f"},
|
||||
{file = "PyMuPDFb-1.24.3-py3-none-win32.whl", hash = "sha256:0d606a10cb828cefc9f864bf67bc9d46e8007af55e643f022b59d378af4151a8"},
|
||||
{file = "PyMuPDFb-1.24.3-py3-none-win_amd64.whl", hash = "sha256:e88289bd4b4afe5966a028774b302f37d4b51dad5c5e6720dd04524910db6c6e"},
|
||||
@ -1304,4 +1352,4 @@ watchmedo = ["PyYAML (>=3.10)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "0073172ce2312055480e9ff47dc99ce7dfd6809208ad5ea4cee5ecf7f12eef56"
|
||||
content-hash = "b21648a1fdc08f901c82fb3b4773682f0a4b83b03b97ae1ddbd0834b730ff8c2"
|
||||
|
@ -43,6 +43,12 @@ codespell = "^2.2.0"
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
# Support Python 3.8 and 3.12+.
|
||||
numpy = [
|
||||
{version = "^1", python = "<3.12"},
|
||||
{version = "^1.26.0", python = ">=3.12"}
|
||||
]
|
||||
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
@ -10,12 +9,10 @@ from langchain_upstage import ChatUpstage
|
||||
|
||||
|
||||
class TestUpstageStandard(ChatModelIntegrationTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatUpstage
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "solar-1-mini-chat",
|
||||
}
|
||||
return {"model": "solar-1-mini-chat"}
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@ -10,12 +9,10 @@ from langchain_upstage import ChatUpstage
|
||||
|
||||
|
||||
class TestUpstageStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatUpstage
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "solar-1-mini-chat",
|
||||
}
|
||||
return {"model": "solar-1-mini-chat"}
|
||||
|
@ -1,74 +1,31 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain_standard_tests.unit_tests.chat_models import (
|
||||
ChatModelTests,
|
||||
my_adder_tool,
|
||||
)
|
||||
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str = Field(..., description="The name of the person.")
|
||||
age: int = Field(..., description="The age of the person.")
|
||||
|
||||
|
||||
@tool
|
||||
def my_adder_tool(a: int, b: int) -> int:
|
||||
"""Takes two integers, a and b, and returns their sum."""
|
||||
return a + b
|
||||
|
||||
|
||||
class ChatModelIntegrationTests(ABC):
|
||||
@abstractmethod
|
||||
@pytest.fixture
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
...
|
||||
|
||||
@pytest.fixture
|
||||
def chat_model_params(self) -> dict:
|
||||
return {}
|
||||
|
||||
@pytest.fixture
|
||||
def chat_model_has_tool_calling(
|
||||
self, chat_model_class: Type[BaseChatModel]
|
||||
) -> bool:
|
||||
return chat_model_class.bind_tools is not BaseChatModel.bind_tools
|
||||
|
||||
@pytest.fixture
|
||||
def chat_model_has_structured_output(
|
||||
self, chat_model_class: Type[BaseChatModel]
|
||||
) -> bool:
|
||||
return (
|
||||
chat_model_class.with_structured_output
|
||||
is not BaseChatModel.with_structured_output
|
||||
)
|
||||
|
||||
def test_invoke(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(**chat_model_params)
|
||||
class ChatModelIntegrationTests(ChatModelTests):
|
||||
def test_invoke(self, model: BaseChatModel) -> None:
|
||||
result = model.invoke("Hello")
|
||||
assert result is not None
|
||||
assert isinstance(result, AIMessage)
|
||||
assert isinstance(result.content, str)
|
||||
assert len(result.content) > 0
|
||||
|
||||
async def test_ainvoke(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(**chat_model_params)
|
||||
async def test_ainvoke(self, model: BaseChatModel) -> None:
|
||||
result = await model.ainvoke("Hello")
|
||||
assert result is not None
|
||||
assert isinstance(result, AIMessage)
|
||||
assert isinstance(result.content, str)
|
||||
assert len(result.content) > 0
|
||||
|
||||
def test_stream(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(**chat_model_params)
|
||||
def test_stream(self, model: BaseChatModel) -> None:
|
||||
num_tokens = 0
|
||||
for token in model.stream("Hello"):
|
||||
assert token is not None
|
||||
@ -76,10 +33,7 @@ class ChatModelIntegrationTests(ABC):
|
||||
num_tokens += len(token.content)
|
||||
assert num_tokens > 0
|
||||
|
||||
async def test_astream(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(**chat_model_params)
|
||||
async def test_astream(self, model: BaseChatModel) -> None:
|
||||
num_tokens = 0
|
||||
async for token in model.astream("Hello"):
|
||||
assert token is not None
|
||||
@ -87,10 +41,7 @@ class ChatModelIntegrationTests(ABC):
|
||||
num_tokens += len(token.content)
|
||||
assert num_tokens > 0
|
||||
|
||||
def test_batch(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(**chat_model_params)
|
||||
def test_batch(self, model: BaseChatModel) -> None:
|
||||
batch_results = model.batch(["Hello", "Hey"])
|
||||
assert batch_results is not None
|
||||
assert isinstance(batch_results, list)
|
||||
@ -101,10 +52,7 @@ class ChatModelIntegrationTests(ABC):
|
||||
assert isinstance(result.content, str)
|
||||
assert len(result.content) > 0
|
||||
|
||||
async def test_abatch(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(**chat_model_params)
|
||||
async def test_abatch(self, model: BaseChatModel) -> None:
|
||||
batch_results = await model.abatch(["Hello", "Hey"])
|
||||
assert batch_results is not None
|
||||
assert isinstance(batch_results, list)
|
||||
@ -115,14 +63,11 @@ class ChatModelIntegrationTests(ABC):
|
||||
assert isinstance(result.content, str)
|
||||
assert len(result.content) > 0
|
||||
|
||||
def test_conversation(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(**chat_model_params)
|
||||
def test_conversation(self, model: BaseChatModel) -> None:
|
||||
messages = [
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="hello"),
|
||||
HumanMessage(content="how are you"),
|
||||
HumanMessage("hello"),
|
||||
AIMessage("hello"),
|
||||
HumanMessage("how are you"),
|
||||
]
|
||||
result = model.invoke(messages)
|
||||
assert result is not None
|
||||
@ -130,10 +75,9 @@ class ChatModelIntegrationTests(ABC):
|
||||
assert isinstance(result.content, str)
|
||||
assert len(result.content) > 0
|
||||
|
||||
def test_usage_metadata(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(**chat_model_params)
|
||||
def test_usage_metadata(self, model: BaseChatModel) -> None:
|
||||
if not self.returns_usage_metadata:
|
||||
pytest.skip("Not implemented.")
|
||||
result = model.invoke("Hello")
|
||||
assert result is not None
|
||||
assert isinstance(result, AIMessage)
|
||||
@ -142,39 +86,35 @@ class ChatModelIntegrationTests(ABC):
|
||||
assert isinstance(result.usage_metadata["output_tokens"], int)
|
||||
assert isinstance(result.usage_metadata["total_tokens"], int)
|
||||
|
||||
def test_stop_sequence(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(**chat_model_params)
|
||||
def test_stop_sequence(self, model: BaseChatModel) -> None:
|
||||
result = model.invoke("hi", stop=["you"])
|
||||
assert isinstance(result, AIMessage)
|
||||
|
||||
model = chat_model_class(**chat_model_params, stop=["you"])
|
||||
result = model.invoke("hi")
|
||||
custom_model = self.chat_model_class(
|
||||
**{**self.chat_model_params, "stop": ["you"]}
|
||||
)
|
||||
result = custom_model.invoke("hi")
|
||||
assert isinstance(result, AIMessage)
|
||||
|
||||
def test_tool_message_histories_string_content(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
chat_model_has_tool_calling: bool,
|
||||
model: BaseChatModel,
|
||||
) -> None:
|
||||
"""
|
||||
Test that message histories are compatible with string tool contents
|
||||
(e.g. OpenAI).
|
||||
"""
|
||||
if not chat_model_has_tool_calling:
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
model = chat_model_class(**chat_model_params)
|
||||
model_with_tools = model.bind_tools([my_adder_tool])
|
||||
function_name = "my_adder_tool"
|
||||
function_args = {"a": "1", "b": "2"}
|
||||
|
||||
messages_string_content = [
|
||||
HumanMessage(content="What is 1 + 2"),
|
||||
HumanMessage("What is 1 + 2"),
|
||||
# string content (e.g. OpenAI)
|
||||
AIMessage(
|
||||
content="",
|
||||
"",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": function_name,
|
||||
@ -184,8 +124,8 @@ class ChatModelIntegrationTests(ABC):
|
||||
],
|
||||
),
|
||||
ToolMessage(
|
||||
json.dumps({"result": 3}),
|
||||
name=function_name,
|
||||
content=json.dumps({"result": 3}),
|
||||
tool_call_id="abc123",
|
||||
),
|
||||
]
|
||||
@ -194,26 +134,23 @@ class ChatModelIntegrationTests(ABC):
|
||||
|
||||
def test_tool_message_histories_list_content(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
chat_model_has_tool_calling: bool,
|
||||
model: BaseChatModel,
|
||||
) -> None:
|
||||
"""
|
||||
Test that message histories are compatible with list tool contents
|
||||
(e.g. Anthropic).
|
||||
"""
|
||||
if not chat_model_has_tool_calling:
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
model = chat_model_class(**chat_model_params)
|
||||
model_with_tools = model.bind_tools([my_adder_tool])
|
||||
function_name = "my_adder_tool"
|
||||
function_args = {"a": 1, "b": 2}
|
||||
|
||||
messages_list_content = [
|
||||
HumanMessage(content="What is 1 + 2"),
|
||||
HumanMessage("What is 1 + 2"),
|
||||
# List content (e.g., Anthropic)
|
||||
AIMessage(
|
||||
content=[
|
||||
[
|
||||
{"type": "text", "text": "some text"},
|
||||
{
|
||||
"type": "tool_use",
|
||||
@ -231,8 +168,8 @@ class ChatModelIntegrationTests(ABC):
|
||||
],
|
||||
),
|
||||
ToolMessage(
|
||||
json.dumps({"result": 3}),
|
||||
name=function_name,
|
||||
content=json.dumps({"result": 3}),
|
||||
tool_call_id="abc123",
|
||||
),
|
||||
]
|
||||
@ -241,25 +178,22 @@ class ChatModelIntegrationTests(ABC):
|
||||
|
||||
def test_structured_few_shot_examples(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
chat_model_has_tool_calling: bool,
|
||||
model: BaseChatModel,
|
||||
) -> None:
|
||||
"""
|
||||
Test that model can process few-shot examples with tool calls.
|
||||
"""
|
||||
if not chat_model_has_tool_calling:
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
model = chat_model_class(**chat_model_params)
|
||||
model_with_tools = model.bind_tools([my_adder_tool])
|
||||
model_with_tools = model.bind_tools([my_adder_tool], tool_choice="any")
|
||||
function_name = "my_adder_tool"
|
||||
function_args = {"a": 1, "b": 2}
|
||||
function_result = json.dumps({"result": 3})
|
||||
|
||||
messages_string_content = [
|
||||
HumanMessage(content="What is 1 + 2"),
|
||||
HumanMessage("What is 1 + 2"),
|
||||
AIMessage(
|
||||
content="",
|
||||
"",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": function_name,
|
||||
@ -269,12 +203,12 @@ class ChatModelIntegrationTests(ABC):
|
||||
],
|
||||
),
|
||||
ToolMessage(
|
||||
function_result,
|
||||
name=function_name,
|
||||
content=function_result,
|
||||
tool_call_id="abc123",
|
||||
),
|
||||
AIMessage(content=function_result),
|
||||
HumanMessage(content="What is 3 + 4"),
|
||||
AIMessage(function_result),
|
||||
HumanMessage("What is 3 + 4"),
|
||||
]
|
||||
result_string_content = model_with_tools.invoke(messages_string_content)
|
||||
assert isinstance(result_string_content, AIMessage)
|
||||
|
@ -1,13 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Optional, Type
|
||||
from typing import Any, List, Literal, Optional, Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
|
||||
from langchain_core.runnables import RunnableBinding
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
class Person(BaseModel):
|
||||
"""Record attributes of a person."""
|
||||
|
||||
name: str = Field(..., description="The name of the person.")
|
||||
age: int = Field(..., description="The age of the person.")
|
||||
|
||||
@ -18,81 +21,105 @@ def my_adder_tool(a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
|
||||
class ChatModelUnitTests(ABC):
|
||||
class ChatModelTests(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
@pytest.fixture
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
...
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {}
|
||||
|
||||
@pytest.fixture
|
||||
def chat_model_has_tool_calling(
|
||||
self, chat_model_class: Type[BaseChatModel]
|
||||
) -> bool:
|
||||
return chat_model_class.bind_tools is not BaseChatModel.bind_tools
|
||||
@property
|
||||
def standard_chat_model_params(self) -> dict:
|
||||
return {
|
||||
"temperature": 0,
|
||||
"max_tokens": 100,
|
||||
"timeout": 60,
|
||||
"stop_sequences": [],
|
||||
"max_retries": 2,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def chat_model_has_structured_output(
|
||||
self, chat_model_class: Type[BaseChatModel]
|
||||
) -> bool:
|
||||
def model(self) -> BaseChatModel:
|
||||
return self.chat_model_class(
|
||||
**{**self.standard_chat_model_params, **self.chat_model_params}
|
||||
)
|
||||
|
||||
@property
|
||||
def has_tool_calling(self) -> bool:
|
||||
return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools
|
||||
|
||||
@property
|
||||
def has_structured_output(self) -> bool:
|
||||
return (
|
||||
chat_model_class.with_structured_output
|
||||
self.chat_model_class.with_structured_output
|
||||
is not BaseChatModel.with_structured_output
|
||||
)
|
||||
|
||||
def test_chat_model_init(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(**chat_model_params)
|
||||
@property
|
||||
def supports_image_inputs(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_video_inputs(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def returns_usage_metadata(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class ChatModelUnitTests(ChatModelTests):
|
||||
@property
|
||||
def standard_chat_model_params(self) -> dict:
|
||||
params = super().standard_chat_model_params
|
||||
params["api_key"] = "test"
|
||||
return params
|
||||
|
||||
def test_init(self) -> None:
|
||||
model = self.chat_model_class(
|
||||
**{**self.standard_chat_model_params, **self.chat_model_params}
|
||||
)
|
||||
assert model is not None
|
||||
|
||||
def test_chat_model_init_api_key(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
params = {**chat_model_params, "api_key": "test"}
|
||||
model = chat_model_class(**params) # type: ignore
|
||||
assert model is not None
|
||||
|
||||
def test_chat_model_init_streaming(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(streaming=True, **chat_model_params) # type: ignore
|
||||
assert model is not None
|
||||
|
||||
def test_chat_model_bind_tool_pydantic(
|
||||
def test_init_streaming(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
chat_model_has_tool_calling: bool,
|
||||
) -> None:
|
||||
if not chat_model_has_tool_calling:
|
||||
model = self.chat_model_class(
|
||||
**{
|
||||
**self.standard_chat_model_params,
|
||||
**self.chat_model_params,
|
||||
"streaming": True,
|
||||
}
|
||||
)
|
||||
assert model is not None
|
||||
|
||||
def test_bind_tool_pydantic(
|
||||
self,
|
||||
model: BaseChatModel,
|
||||
) -> None:
|
||||
if not self.has_tool_calling:
|
||||
return
|
||||
|
||||
model = chat_model_class(**chat_model_params)
|
||||
tool_model = model.bind_tools(
|
||||
[Person, Person.schema(), my_adder_tool], tool_choice="any"
|
||||
)
|
||||
assert isinstance(tool_model, RunnableBinding)
|
||||
|
||||
assert hasattr(model, "bind_tools")
|
||||
tool_model = model.bind_tools([Person])
|
||||
assert tool_model is not None
|
||||
|
||||
def test_chat_model_with_structured_output(
|
||||
@pytest.mark.parametrize("schema", [Person, Person.schema()])
|
||||
def test_with_structured_output(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
chat_model_has_structured_output: bool,
|
||||
model: BaseChatModel,
|
||||
schema: Any,
|
||||
) -> None:
|
||||
if not chat_model_has_structured_output:
|
||||
if not self.has_structured_output:
|
||||
return
|
||||
|
||||
model = chat_model_class(**chat_model_params)
|
||||
assert model is not None
|
||||
assert model.with_structured_output(Person) is not None
|
||||
assert model.with_structured_output(schema) is not None
|
||||
|
||||
def test_standard_params(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
def test_standard_params(self, model: BaseChatModel) -> None:
|
||||
class ExpectedParams(BaseModel):
|
||||
ls_provider: str
|
||||
ls_model_name: str
|
||||
@ -101,7 +128,6 @@ class ChatModelUnitTests(ABC):
|
||||
ls_max_tokens: Optional[int]
|
||||
ls_stop: Optional[List[str]]
|
||||
|
||||
model = chat_model_class(**chat_model_params)
|
||||
ls_params = model._get_ls_params()
|
||||
try:
|
||||
ExpectedParams(**ls_params)
|
||||
@ -109,7 +135,9 @@ class ChatModelUnitTests(ABC):
|
||||
pytest.fail(f"Validation error: {e}")
|
||||
|
||||
# Test optional params
|
||||
model = chat_model_class(max_tokens=10, stop=["test"], **chat_model_params)
|
||||
model = self.chat_model_class(
|
||||
max_tokens=10, stop=["test"], **self.chat_model_params
|
||||
)
|
||||
ls_params = model._get_ls_params()
|
||||
try:
|
||||
ExpectedParams(**ls_params)
|
||||
|
Loading…
Reference in New Issue
Block a user