standard-tests[patch]: Update chat model standard tests (#22378)

- Refactor standard test classes to make them easier to configure
- Update openai to support stop_sequences init param
- Update groq to support stop_sequences init param
- Update fireworks to support max_retries init param
- Update ChatModel.bind_tools to type tool_choice
- Update groq to handle tool_choice="any". **this may be controversial**

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Bagatur
2024-06-17 13:37:41 -07:00
committed by GitHub
parent 14f0cdad58
commit d96f67b06f
31 changed files with 383 additions and 378 deletions

View File

@@ -296,6 +296,8 @@ class ChatFireworks(BaseChatModel):
"""Model name to use."""
temperature: float = 0.0
"""What sampling temperature to use."""
stop: Optional[Union[str, List[str]]] = Field(None, alias="stop_sequences")
"""Default stop sequences."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
@@ -314,8 +316,8 @@ class ChatFireworks(BaseChatModel):
"""Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
stop: Optional[List[str]] = Field(None, alias="stop_sequences")
"""Default stop sequences."""
max_retries: Optional[int] = None
"""Maximum number of retries to make when generating."""
class Config:
"""Configuration for this pydantic object."""
@@ -360,6 +362,9 @@ class ChatFireworks(BaseChatModel):
values["client"] = Fireworks(**client_params).chat.completions
if not values.get("async_client"):
values["async_client"] = AsyncFireworks(**client_params).chat.completions
if values["max_retries"]:
values["client"]._max_retries = values["max_retries"]
values["async_client"]._max_retries = values["max_retries"]
return values
@property