mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 20:09:01 +00:00
standard-tests[patch]: Update chat model standard tests (#22378)
- Refactor standard test classes to make them easier to configure - Update openai to support stop_sequences init param - Update groq to support stop_sequences init param - Update fireworks to support max_retries init param - Update ChatModel.bind_tools to type tool_choice - Update groq to handle tool_choice="any". **this may be controversial** --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
@@ -2,23 +2,31 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.runnables import RunnableBinding
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
from langchain_standard_tests.unit_tests.chat_models import Person, my_adder_tool
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
|
||||
class TestOpenAIStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return AzureChatOpenAI
|
||||
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"deployment_name": "test",
|
||||
"openai_api_version": "2021-10-01",
|
||||
"azure_endpoint": "https://test.azure.com",
|
||||
"openai_api_key": "test",
|
||||
}
|
||||
|
||||
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
|
||||
"""Does not currently support tool_choice='any'."""
|
||||
if not self.has_tool_calling:
|
||||
return
|
||||
|
||||
tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool])
|
||||
assert isinstance(tool_model, RunnableBinding)
|
||||
|
@@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
@@ -10,6 +9,6 @@ from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class TestOpenAIStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatOpenAI
|
||||
|
Reference in New Issue
Block a user