mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 06:53:16 +00:00
standard-tests[patch]: Update chat model standard tests (#22378)
- Refactor standard test classes to make them easier to configure - Update openai to support stop_sequences init param - Update groq to support stop_sequences init param - Update fireworks to support max_retries init param - Update ChatModel.bind_tools to type tool_choice - Update groq to handle tool_choice="any". **this may be controversial** --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
@@ -3,13 +3,18 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Type, Union
|
||||
|
||||
import openai
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import LangSmithParams
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||
|
||||
@@ -210,6 +215,27 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
).chat.completions
|
||||
return values
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
# As of 05/2024 Azure OpenAI doesn't support tool_choice="required".
|
||||
# TODO: Update this condition once tool_choice="required" is supported.
|
||||
if tool_choice in ("any", "required", True):
|
||||
if len(tools) > 1:
|
||||
raise ValueError(
|
||||
f"Azure OpenAI does not currently support {tool_choice=}. Should "
|
||||
f"be one of 'auto', 'none', or the name of the tool to call."
|
||||
)
|
||||
else:
|
||||
tool_choice = convert_to_openai_tool(tools[0])["function"]["name"]
|
||||
return super().bind_tools(tools, tool_choice=tool_choice, **kwargs)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
|
@@ -345,6 +345,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
http_async_client: Union[Any, None] = None
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
http_client as well if you'd like a custom client for sync invocations."""
|
||||
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -441,6 +443,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"stream": self.streaming,
|
||||
"n": self.n,
|
||||
"temperature": self.temperature,
|
||||
"stop": self.stop,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
if self.max_tokens is not None:
|
||||
@@ -548,8 +551,6 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params = self._default_params
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
@@ -871,15 +872,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if tool_choice == "any":
|
||||
tool_choice = "required"
|
||||
elif isinstance(tool_choice, bool):
|
||||
if len(tools) > 1:
|
||||
raise ValueError(
|
||||
"tool_choice=True can only be used when a single tool is "
|
||||
f"passed in, received {len(tools)} tools."
|
||||
)
|
||||
tool_choice = {
|
||||
"type": "function",
|
||||
"function": {"name": formatted_tools[0]["function"]["name"]},
|
||||
}
|
||||
tool_choice = "required"
|
||||
elif isinstance(tool_choice, dict):
|
||||
tool_names = [
|
||||
formatted_tool["function"]["name"]
|
||||
@@ -1094,7 +1087,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"schema must be specified when method is 'function_calling'. "
|
||||
"Received None."
|
||||
)
|
||||
llm = self.bind_tools([schema], tool_choice=True)
|
||||
llm = self.bind_tools([schema], tool_choice="any")
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], first_tool_only=True
|
||||
|
@@ -440,8 +440,6 @@ class BaseOpenAI(BaseLLM):
|
||||
) -> List[List[str]]:
|
||||
"""Get the sub prompts for llm call."""
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
if params["max_tokens"] == -1:
|
||||
if len(prompts) != 1:
|
||||
|
Reference in New Issue
Block a user