mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 02:53:16 +00:00
fix(anthropic): support eager_input_streaming (#35779)
This commit is contained in:
@@ -166,6 +166,7 @@ _ANTHROPIC_EXTRA_FIELDS: set[str] = {
|
||||
"allowed_callers",
|
||||
"cache_control",
|
||||
"defer_loading",
|
||||
"eager_input_streaming",
|
||||
"input_examples",
|
||||
}
|
||||
"""Valid Anthropic-specific extra fields"""
|
||||
|
||||
@@ -16,7 +16,7 @@ from blockbuster import blockbuster_ctx
|
||||
from langchain_core.exceptions import ContextOverflowError
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.runnables import RunnableBinding
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
from pydantic import BaseModel, Field, SecretStr, ValidationError
|
||||
@@ -2007,9 +2007,9 @@ def test_tool_search_with_deferred_tools() -> None:
|
||||
|
||||
# Find the calculator tool in the payload
|
||||
calculator_tool = None
|
||||
for tool in payload["tools"]:
|
||||
if isinstance(tool, dict) and tool.get("name") == "calculator":
|
||||
calculator_tool = tool
|
||||
for tool_ in payload["tools"]:
|
||||
if isinstance(tool_, dict) and tool_.get("name") == "calculator":
|
||||
calculator_tool = tool_
|
||||
break
|
||||
|
||||
assert calculator_tool is not None
|
||||
@@ -2264,7 +2264,6 @@ def test_output_config_without_effort() -> None:
|
||||
|
||||
def test_extras_with_defer_loading() -> None:
|
||||
"""Test that extras with `defer_loading` are merged into tool definitions."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool(extras={"defer_loading": True})
|
||||
def get_weather(location: str) -> str:
|
||||
@@ -2293,7 +2292,6 @@ def test_extras_with_defer_loading() -> None:
|
||||
|
||||
def test_extras_with_cache_control() -> None:
|
||||
"""Test that extras with `cache_control` are merged into tool definitions."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool(extras={"cache_control": {"type": "ephemeral"}})
|
||||
def search_files(query: str) -> str:
|
||||
@@ -2318,9 +2316,31 @@ def test_extras_with_cache_control() -> None:
|
||||
assert search_tool.get("cache_control") == {"type": "ephemeral"}
|
||||
|
||||
|
||||
def test_extras_with_fine_grained_streaming() -> None:
|
||||
@tool(extras={"eager_input_streaming": True})
|
||||
def tell_story(story: str) -> None:
|
||||
"""Tell a story."""
|
||||
|
||||
model = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg]
|
||||
model_with_tools = model.bind_tools([tell_story])
|
||||
|
||||
payload = model_with_tools._get_request_payload( # type: ignore[attr-defined]
|
||||
"test",
|
||||
**model_with_tools.kwargs, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
tell_story_tool = None
|
||||
for tool_def in payload["tools"]:
|
||||
if isinstance(tool_def, dict) and tool_def.get("name") == "tell_story":
|
||||
tell_story_tool = tool_def
|
||||
break
|
||||
|
||||
assert tell_story_tool is not None
|
||||
assert tell_story_tool.get("eager_input_streaming") is True
|
||||
|
||||
|
||||
def test_extras_with_input_examples() -> None:
|
||||
"""Test that extras with `input_examples` are merged into tool definitions."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool(
|
||||
extras={
|
||||
@@ -2363,7 +2383,6 @@ def test_extras_with_input_examples() -> None:
|
||||
|
||||
def test_extras_with_multiple_fields() -> None:
|
||||
"""Test that multiple extra fields can be specified together."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool(
|
||||
extras={
|
||||
|
||||
Reference in New Issue
Block a user