diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 43904445a74..57b8e01792b 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -166,6 +166,7 @@ _ANTHROPIC_EXTRA_FIELDS: set[str] = { "allowed_callers", "cache_control", "defer_loading", + "eager_input_streaming", "input_examples", } """Valid Anthropic-specific extra fields""" diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index 1fa7a4b5cf6..02862bdef74 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -16,7 +16,7 @@ from blockbuster import blockbuster_ctx from langchain_core.exceptions import ContextOverflowError from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import RunnableBinding -from langchain_core.tools import BaseTool +from langchain_core.tools import BaseTool, tool from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.schemas import Run from pydantic import BaseModel, Field, SecretStr, ValidationError @@ -2007,9 +2007,9 @@ def test_tool_search_with_deferred_tools() -> None: # Find the calculator tool in the payload calculator_tool = None - for tool in payload["tools"]: - if isinstance(tool, dict) and tool.get("name") == "calculator": - calculator_tool = tool + for tool_ in payload["tools"]: + if isinstance(tool_, dict) and tool_.get("name") == "calculator": + calculator_tool = tool_ break assert calculator_tool is not None @@ -2264,7 +2264,6 @@ def test_output_config_without_effort() -> None: def test_extras_with_defer_loading() -> None: """Test that extras with `defer_loading` are merged into tool definitions.""" - from langchain_core.tools import tool @tool(extras={"defer_loading": True}) def get_weather(location: str) -> str: @@ -2293,7 +2292,6 @@ def test_extras_with_defer_loading() -> None: def test_extras_with_cache_control() -> None: """Test that extras with `cache_control` are merged into tool definitions.""" - from langchain_core.tools import tool @tool(extras={"cache_control": {"type": "ephemeral"}}) def search_files(query: str) -> str: @@ -2318,9 +2316,31 @@ def test_extras_with_cache_control() -> None: assert search_tool.get("cache_control") == {"type": "ephemeral"} +def test_extras_with_fine_grained_streaming() -> None: + @tool(extras={"eager_input_streaming": True}) + def tell_story(story: str) -> None: + """Tell a story.""" + + model = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg] + model_with_tools = model.bind_tools([tell_story]) + + payload = model_with_tools._get_request_payload( # type: ignore[attr-defined] + "test", + **model_with_tools.kwargs, # type: ignore[attr-defined] + ) + + tell_story_tool = None + for tool_def in payload["tools"]: + if isinstance(tool_def, dict) and tool_def.get("name") == "tell_story": + tell_story_tool = tool_def + break + + assert tell_story_tool is not None + assert tell_story_tool.get("eager_input_streaming") is True + + def test_extras_with_input_examples() -> None: """Test that extras with `input_examples` are merged into tool definitions.""" - from langchain_core.tools import tool @tool( extras={ @@ -2363,7 +2383,6 @@ def test_extras_with_input_examples() -> None: def test_extras_with_multiple_fields() -> None: """Test that multiple extra fields can be specified together.""" - from langchain_core.tools import tool @tool( extras={