
together, standard-tests: specify tool_choice in standard tests ()

Here we allow standard tests to specify a value for `tool_choice` via a
`tool_choice_value` property, which defaults to None.

Chat models [available in
Together](https://docs.together.ai/docs/chat-models) have issues passing
standard tool calling tests:
- llama 3.1 models currently [appear to rely on user-side
parsing](https://docs.together.ai/docs/llama-3-function-calling) in
Together;
- Mixtral-8x7B and Mistral-7B (the models currently under test) consistently
fail to call tools in some tests.

Specifying tool_choice also lets us remove an existing `xfail` and use a
smaller model in Groq tests.
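
To opt in from an integration package, a test class overrides the new property. A minimal sketch, where the `langchain_my_provider` package, `ChatMyProvider` class, and model name are hypothetical stand-ins (the overridden properties are the real hooks from `langchain_standard_tests`):

```python
from typing import Optional, Type

from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_my_provider import ChatMyProvider  # hypothetical integration


class TestMyProviderStandard(ChatModelIntegrationTests):
    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatMyProvider

    @property
    def chat_model_params(self) -> dict:
        return {"model": "my-model-name", "temperature": 0}

    @property
    def tool_choice_value(self) -> Optional[str]:
        """Value to use for tool choice when used in tests.

        "any" forces the model to call some tool; the sentinel "tool_name"
        tells the standard tests to pass the bound tool's name instead;
        the default of None leaves tool_choice unset.
        """
        return "any"
```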
ccurme 2024-08-19 16:37:36 -04:00 committed by GitHub
parent 015ab91b83
commit c5bf114c0f
6 changed files with 83 additions and 10 deletions:
- libs/partners/groq/tests/integration_tests
- libs/partners/mistralai/tests/integration_tests
- libs/partners/together/tests/integration_tests
- libs/standard-tests/langchain_standard_tests/integration_tests
- libs/standard-tests/langchain_standard_tests/unit_tests

File: libs/partners/groq/tests/integration_tests

@@ -14,6 +14,7 @@ from langchain_core.messages import (
 )
 from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import tool
 from langchain_groq import ChatGroq
 from tests.unit_tests.fake.callbacks import (
@@ -393,6 +394,42 @@ def test_json_mode_structured_output() -> None:
     assert len(result.punchline) != 0
 
 
+def test_tool_calling_no_arguments() -> None:
+    # Note: this is a variant of a test in langchain_standard_tests
+    # that as of 2024-08-19 fails with "Failed to call a function. Please
+    # adjust your prompt." when `tool_choice="any"` is specified, but
+    # passes when `tool_choice` is not specified.
+    model = ChatGroq(model="llama-3.1-70b-versatile", temperature=0)  # type: ignore[call-arg]
+
+    @tool
+    def magic_function_no_args() -> int:
+        """Calculates a magic function."""
+        return 5
+
+    model_with_tools = model.bind_tools([magic_function_no_args])
+    query = "What is the value of magic_function()? Use the tool."
+    result = model_with_tools.invoke(query)
+    assert isinstance(result, AIMessage)
+    assert len(result.tool_calls) == 1
+    tool_call = result.tool_calls[0]
+    assert tool_call["name"] == "magic_function_no_args"
+    assert tool_call["args"] == {}
+    assert tool_call["id"] is not None
+    assert tool_call["type"] == "tool_call"
+
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    for chunk in model_with_tools.stream(query):
+        full = chunk if full is None else full + chunk  # type: ignore
+    assert isinstance(full, AIMessage)
+    assert len(full.tool_calls) == 1
+    tool_call = full.tool_calls[0]
+    assert tool_call["name"] == "magic_function_no_args"
+    assert tool_call["args"] == {}
+    assert tool_call["id"] is not None
+    assert tool_call["type"] == "tool_call"
+
+
 # Groq does not currently support N > 1
 # @pytest.mark.scheduled
 # def test_chat_multiple_completions() -> None:

File: libs/partners/groq/tests/integration_tests (standard tests)

@@ -1,6 +1,6 @@
 """Standard LangChain interface tests"""
 
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 from langchain_core.language_models import BaseChatModel
@@ -28,11 +28,22 @@ class TestGroqLlama(BaseTestGroq):
     @property
     def chat_model_params(self) -> dict:
         return {
-            "model": "llama-3.1-70b-versatile",
+            "model": "llama-3.1-8b-instant",
             "temperature": 0,
             "rate_limiter": rate_limiter,
         }
 
+    @property
+    def tool_choice_value(self) -> Optional[str]:
+        """Value to use for tool choice when used in tests."""
+        return "any"
+
+    @pytest.mark.xfail(
+        reason=("Fails with 'Failed to call a function. Please adjust your prompt.'")
+    )
     def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
         super().test_tool_calling_with_no_arguments(model)
 
+    @pytest.mark.xfail(
+        reason=("Fails with 'Failed to call a function. Please adjust your prompt.'")
+    )

File: libs/partners/mistralai/tests/integration_tests (standard tests)

@@ -1,6 +1,6 @@
 """Standard LangChain interface tests"""
 
-from typing import Type
+from typing import Optional, Type
 
 from langchain_core.language_models import BaseChatModel
 from langchain_standard_tests.integration_tests import (  # type: ignore[import-not-found]
@@ -18,3 +18,8 @@ class TestMistralStandard(ChatModelIntegrationTests):
     @property
     def chat_model_params(self) -> dict:
         return {"model": "mistral-large-latest", "temperature": 0}
+
+    @property
+    def tool_choice_value(self) -> Optional[str]:
+        """Value to use for tool choice when used in tests."""
+        return "any"

File: libs/partners/together/tests/integration_tests (standard tests)

@@ -1,6 +1,6 @@
 """Standard LangChain interface tests"""
 
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 from langchain_core.language_models import BaseChatModel
@@ -28,9 +28,10 @@ class TestTogetherStandard(ChatModelIntegrationTests):
             "rate_limiter": rate_limiter,
         }
 
-    @pytest.mark.xfail(reason=("May not call a tool."))
-    def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
-        super().test_tool_calling_with_no_arguments(model)
+    @property
+    def tool_choice_value(self) -> Optional[str]:
+        """Value to use for tool choice when used in tests."""
+        return "tool_name"
 
     @pytest.mark.xfail(reason="Not yet supported.")
     def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:

File: libs/standard-tests/langchain_standard_tests/integration_tests

@@ -170,7 +170,11 @@ class ChatModelIntegrationTests(ChatModelTests):
     def test_tool_calling(self, model: BaseChatModel) -> None:
         if not self.has_tool_calling:
             pytest.skip("Test requires tool calling.")
-        model_with_tools = model.bind_tools([magic_function])
+        if self.tool_choice_value == "tool_name":
+            tool_choice: Optional[str] = "magic_function"
+        else:
+            tool_choice = self.tool_choice_value
+        model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice)
 
         # Test invoke
         query = "What is the value of magic_function(3)? Use the tool."
@@ -188,7 +192,13 @@ class ChatModelIntegrationTests(ChatModelTests):
         if not self.has_tool_calling:
             pytest.skip("Test requires tool calling.")
-        model_with_tools = model.bind_tools([magic_function_no_args])
+        if self.tool_choice_value == "tool_name":
+            tool_choice: Optional[str] = "magic_function_no_args"
+        else:
+            tool_choice = self.tool_choice_value
+        model_with_tools = model.bind_tools(
+            [magic_function_no_args], tool_choice=tool_choice
+        )
         query = "What is the value of magic_function()? Use the tool."
         result = model_with_tools.invoke(query)
         _validate_tool_call_message_no_args(result)
@@ -212,7 +222,11 @@ class ChatModelIntegrationTests(ChatModelTests):
             name="greeting_generator",
             description="Generate a greeting in a particular style of speaking.",
         )
-        model_with_tools = model.bind_tools([tool_])
+        if self.tool_choice_value == "tool_name":
+            tool_choice: Optional[str] = "greeting_generator"
+        else:
+            tool_choice = self.tool_choice_value
+        model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice)
         query = "Using the tool, generate a Pirate greeting."
         result = model_with_tools.invoke(query)
         assert isinstance(result, AIMessage)
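
The "tool_name" value in the branches above is a sentinel: each test binds a different tool, so the concrete name is substituted per test. The same resolution logic restated as a standalone helper, for illustration only (`_resolve_tool_choice` is not part of this commit):

```python
from typing import Optional


def _resolve_tool_choice(
    tool_choice_value: Optional[str], tool_name: str
) -> Optional[str]:
    """Map the suite-level setting to a per-test tool_choice argument.

    None        -> tool_choice stays unset (provider decides whether to call)
    "any"       -> forwarded as-is, forcing the model to call some tool
    "tool_name" -> replaced by the concrete tool's name for this test
    """
    if tool_choice_value == "tool_name":
        return tool_name
    return tool_choice_value


assert _resolve_tool_choice(None, "magic_function") is None
assert _resolve_tool_choice("any", "magic_function") == "any"
assert _resolve_tool_choice("tool_name", "magic_function") == "magic_function"
```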

File: libs/standard-tests/langchain_standard_tests/unit_tests

@@ -96,6 +96,11 @@ class ChatModelTests(BaseStandardTests):
     def has_tool_calling(self) -> bool:
         return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools
 
+    @property
+    def tool_choice_value(self) -> Optional[str]:
+        """Value to use for tool choice when used in tests."""
+        return None
+
     @property
     def has_structured_output(self) -> bool:
         return (
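
For reference, a minimal sketch of the call shape these defaults feed into, using ChatGroq as the concrete model (assumes `langchain-groq` is installed and `GROQ_API_KEY` is set; `magic_function` mirrors the tool the standard tests bind):

```python
from langchain_core.messages import AIMessage
from langchain_core.tools import tool
from langchain_groq import ChatGroq


@tool
def magic_function(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2


# tool_choice="any" requires the model to call some tool; a tool name
# requires that specific tool; None leaves the decision to the model.
model = ChatGroq(model="llama-3.1-8b-instant", temperature=0)
model_with_tools = model.bind_tools([magic_function], tool_choice="any")

result = model_with_tools.invoke(
    "What is the value of magic_function(3)? Use the tool."
)
assert isinstance(result, AIMessage)
assert result.tool_calls[0]["name"] == "magic_function"
```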