mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-07 22:11:51 +00:00
standard-tests[patch]: Update chat model standard tests (#22378)
- Refactor standard test classes to make them easier to configure - Update openai to support stop_sequences init param - Update groq to support stop_sequences init param - Update fireworks to support max_retries init param - Update ChatModel.bind_tools to type tool_choice - Update groq to handle tool_choice="any". **this may be controversial** --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
@@ -3,13 +3,18 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Type, Union
|
||||
|
||||
import openai
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import LangSmithParams
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||
|
||||
@@ -210,6 +215,27 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
).chat.completions
|
||||
return values
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
# As of 05/2024 Azure OpenAI doesn't support tool_choice="required".
|
||||
# TODO: Update this condition once tool_choice="required" is supported.
|
||||
if tool_choice in ("any", "required", True):
|
||||
if len(tools) > 1:
|
||||
raise ValueError(
|
||||
f"Azure OpenAI does not currently support {tool_choice=}. Should "
|
||||
f"be one of 'auto', 'none', or the name of the tool to call."
|
||||
)
|
||||
else:
|
||||
tool_choice = convert_to_openai_tool(tools[0])["function"]["name"]
|
||||
return super().bind_tools(tools, tool_choice=tool_choice, **kwargs)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
|
Reference in New Issue
Block a user