mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-06 07:38:50 +00:00
together, standard-tests: specify tool_choice in standard tests (#25548)
Here we allow standard tests to specify a value for `tool_choice` via a `tool_choice_value` property, which defaults to `None`.

Chat models [available in Together](https://docs.together.ai/docs/chat-models) have issues passing standard tool calling tests:

- llama 3.1 models currently [appear to rely on user-side parsing](https://docs.together.ai/docs/llama-3-function-calling) in Together;
- Mixtral-8x7B and Mistral-7B (currently tested) consistently do not call tools in some tests.

Specifying `tool_choice` also lets us remove an existing `xfail` and use a smaller model in Groq tests.
This commit is contained in:
parent
015ab91b83
commit
c5bf114c0f
libs
partners
groq/tests/integration_tests
mistralai/tests/integration_tests
together/tests/integration_tests
standard-tests/langchain_standard_tests
@ -14,6 +14,7 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain_groq import ChatGroq
|
||||
from tests.unit_tests.fake.callbacks import (
|
||||
@ -393,6 +394,42 @@ def test_json_mode_structured_output() -> None:
|
||||
assert len(result.punchline) != 0
|
||||
|
||||
|
||||
def test_tool_calling_no_arguments() -> None:
    """Verify a zero-argument tool is called through both invoke and stream.

    Note: this is a variant of a test in langchain_standard_tests that, as of
    2024-08-19, fails with "Failed to call a function. Please adjust your
    prompt." when `tool_choice="any"` is specified, but passes when
    `tool_choice` is not specified.
    """

    @tool
    def magic_function_no_args() -> int:
        """Calculates a magic function."""
        return 5

    def _assert_single_no_arg_call(message: object) -> None:
        # Shared checks: exactly one well-formed call to the no-arg tool.
        assert isinstance(message, AIMessage)
        assert len(message.tool_calls) == 1
        call = message.tool_calls[0]
        assert call["name"] == "magic_function_no_args"
        assert call["args"] == {}
        assert call["id"] is not None
        assert call["type"] == "tool_call"

    llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0)  # type: ignore[call-arg]
    llm_with_tools = llm.bind_tools([magic_function_no_args])
    prompt = "What is the value of magic_function()? Use the tool."

    # Single-shot invocation.
    _assert_single_no_arg_call(llm_with_tools.invoke(prompt))

    # Test streaming: accumulate the chunks into one message before checking.
    aggregated: Optional[BaseMessageChunk] = None
    for piece in llm_with_tools.stream(prompt):
        if aggregated is None:
            aggregated = piece
        else:
            aggregated = aggregated + piece  # type: ignore
    _assert_single_no_arg_call(aggregated)
|
||||
|
||||
|
||||
# Groq does not currently support N > 1
|
||||
# @pytest.mark.scheduled
|
||||
# def test_chat_multiple_completions() -> None:
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Type
|
||||
from typing import Optional, Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
@ -28,11 +28,22 @@ class TestGroqLlama(BaseTestGroq):
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "llama-3.1-70b-versatile",
|
||||
"model": "llama-3.1-8b-instant",
|
||||
"temperature": 0,
|
||||
"rate_limiter": rate_limiter,
|
||||
}
|
||||
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return "any"
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=("Fails with 'Failed to call a function. Please adjust your prompt.'")
|
||||
)
|
||||
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_calling_with_no_arguments(model)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=("Fails with 'Failed to call a function. Please adjust your prompt.'")
|
||||
)
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found]
|
||||
@ -18,3 +18,8 @@ class TestMistralStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {"model": "mistral-large-latest", "temperature": 0}
|
||||
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return "any"
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Type
|
||||
from typing import Optional, Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
@ -28,9 +28,10 @@ class TestTogetherStandard(ChatModelIntegrationTests):
|
||||
"rate_limiter": rate_limiter,
|
||||
}
|
||||
|
||||
@pytest.mark.xfail(reason=("May not call a tool."))
|
||||
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_calling_with_no_arguments(model)
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return "tool_name"
|
||||
|
||||
@pytest.mark.xfail(reason="Not yet supported.")
|
||||
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
|
||||
|
@ -170,7 +170,11 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
def test_tool_calling(self, model: BaseChatModel) -> None:
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
model_with_tools = model.bind_tools([magic_function])
|
||||
if self.tool_choice_value == "tool_name":
|
||||
tool_choice: Optional[str] = "magic_function"
|
||||
else:
|
||||
tool_choice = self.tool_choice_value
|
||||
model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice)
|
||||
|
||||
# Test invoke
|
||||
query = "What is the value of magic_function(3)? Use the tool."
|
||||
@ -188,7 +192,13 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
||||
model_with_tools = model.bind_tools([magic_function_no_args])
|
||||
if self.tool_choice_value == "tool_name":
|
||||
tool_choice: Optional[str] = "magic_function_no_args"
|
||||
else:
|
||||
tool_choice = self.tool_choice_value
|
||||
model_with_tools = model.bind_tools(
|
||||
[magic_function_no_args], tool_choice=tool_choice
|
||||
)
|
||||
query = "What is the value of magic_function()? Use the tool."
|
||||
result = model_with_tools.invoke(query)
|
||||
_validate_tool_call_message_no_args(result)
|
||||
@ -212,7 +222,11 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
name="greeting_generator",
|
||||
description="Generate a greeting in a particular style of speaking.",
|
||||
)
|
||||
model_with_tools = model.bind_tools([tool_])
|
||||
if self.tool_choice_value == "tool_name":
|
||||
tool_choice: Optional[str] = "greeting_generator"
|
||||
else:
|
||||
tool_choice = self.tool_choice_value
|
||||
model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice)
|
||||
query = "Using the tool, generate a Pirate greeting."
|
||||
result = model_with_tools.invoke(query)
|
||||
assert isinstance(result, AIMessage)
|
||||
|
@ -96,6 +96,11 @@ class ChatModelTests(BaseStandardTests):
|
||||
def has_tool_calling(self) -> bool:
|
||||
return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools
|
||||
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def has_structured_output(self) -> bool:
|
||||
return (
|
||||
|
Loading…
Reference in New Issue
Block a user