diff --git a/libs/partners/ai21/tests/integration_tests/test_standard.py b/libs/partners/ai21/tests/integration_tests/test_standard.py
index b83cf5127ba..0774faf5e26 100644
--- a/libs/partners/ai21/tests/integration_tests/test_standard.py
+++ b/libs/partners/ai21/tests/integration_tests/test_standard.py
@@ -33,14 +33,18 @@ class TestAI21J2(BaseTestAI21):
             "model": "j2-ultra",
         }
 
-    @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
+    @pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
     def test_stream(self, model: BaseChatModel) -> None:
         super().test_stream(model)
 
-    @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
+    @pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
     async def test_astream(self, model: BaseChatModel) -> None:
         await super().test_astream(model)
 
+    @pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
+    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
+        super().test_usage_metadata_streaming(model)
+
 
 class TestAI21Jamba(BaseTestAI21):
     @property
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py
index 8cddaac726e..41c20a942ac 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py
@@ -3,6 +3,7 @@
 import os
 from typing import Type
 
+import pytest
 from langchain_core.language_models import BaseChatModel
 from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
 
@@ -30,3 +31,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
             "azure_endpoint": OPENAI_API_BASE,
             "api_key": OPENAI_API_KEY,
         }
+
+    @pytest.mark.xfail(reason="Not yet supported.")
+    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
+        super().test_usage_metadata_streaming(model)
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py
index 0227a07f787..888998b2371 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py
@@ -18,7 +18,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
 
     @property
     def chat_model_params(self) -> dict:
-        return {"model": "gpt-4o"}
+        return {"model": "gpt-4o", "stream_usage": True}
 
     @property
     def supports_image_inputs(self) -> bool:
diff --git a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py
index b91f9e784c3..4397725837d 100644
--- a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py
+++ b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py
@@ -23,3 +23,7 @@ class TestTogetherStandard(ChatModelIntegrationTests):
     @pytest.mark.xfail(reason=("May not call a tool."))
     def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
         super().test_tool_calling_with_no_arguments(model)
+
+    @pytest.mark.xfail(reason="Not yet supported.")
+    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
+        super().test_usage_metadata_streaming(model)
diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py
index a0099844362..2aaba72c030 100644
--- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py
+++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py
@@ -4,7 +4,7 @@ from typing import List, Optional
 
 import httpx
 import pytest
-from langchain_core.language_models import BaseChatModel
+from langchain_core.language_models import BaseChatModel, GenericFakeChatModel
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -14,6 +14,8 @@ from langchain_core.messages import (
     SystemMessage,
     ToolMessage,
 )
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_core.tools import tool
 
@@ -129,6 +131,19 @@ class ChatModelIntegrationTests(ChatModelTests):
         assert isinstance(result.usage_metadata["output_tokens"], int)
         assert isinstance(result.usage_metadata["total_tokens"], int)
 
+    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
+        if not self.returns_usage_metadata:
+            pytest.skip("Not implemented.")
+        full: Optional[BaseMessageChunk] = None
+        for chunk in model.stream("Hello"):
+            assert isinstance(chunk, AIMessageChunk)
+            full = chunk if full is None else full + chunk
+        assert isinstance(full, AIMessageChunk)
+        assert full.usage_metadata is not None
+        assert isinstance(full.usage_metadata["input_tokens"], int)
+        assert isinstance(full.usage_metadata["output_tokens"], int)
+        assert isinstance(full.usage_metadata["total_tokens"], int)
+
     def test_stop_sequence(self, model: BaseChatModel) -> None:
         result = model.invoke("hi", stop=["you"])
         assert isinstance(result, AIMessage)
@@ -171,6 +186,23 @@ class ChatModelIntegrationTests(ChatModelTests):
         assert isinstance(full, AIMessage)
         _validate_tool_call_message_no_args(full)
 
+    def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
+        if not self.has_tool_calling:
+            pytest.skip("Test requires tool calling.")
+
+        prompt = ChatPromptTemplate.from_messages(
+            [("human", "Hello. Please respond in the style of {answer_style}.")]
+        )
+        llm = GenericFakeChatModel(messages=iter(["hello matey"]))
+        chain = prompt | llm | StrOutputParser()
+        model_with_tools = model.bind_tools([chain.as_tool()])
+        query = "Using the tool, ask a Pirate how it would say hello."
+        result = model_with_tools.invoke(query)
+        assert isinstance(result, AIMessage)
+        assert result.tool_calls
+        tool_call = result.tool_calls[0]
+        assert tool_call["args"].get("answer_style")
+
     def test_structured_output(self, model: BaseChatModel) -> None:
         if not self.has_tool_calling:
             pytest.skip("Test requires tool calling.")
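
For reference, below is a minimal, hypothetical sketch (not part of this patch) of how a downstream partner package would pick up the new test_usage_metadata_streaming check simply by subclassing ChatModelIntegrationTests, and how it would temporarily opt out while streamed usage metadata is unsupported, mirroring the Azure OpenAI and Together changes above. The package langchain_parrot_link, the class ChatParrotLink, and the model name "bird-brain-001" are placeholders.

from typing import Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

# Placeholder import: langchain_parrot_link / ChatParrotLink are not real packages.
from langchain_parrot_link import ChatParrotLink


class TestParrotLinkStandard(ChatModelIntegrationTests):
    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatParrotLink

    @property
    def chat_model_params(self) -> dict:
        return {"model": "bird-brain-001"}  # placeholder model name

    # Until the provider populates usage metadata on streamed chunks, mark the
    # inherited check as an expected failure, as done for Azure OpenAI and Together.
    @pytest.mark.xfail(reason="Usage metadata not yet returned when streaming.")
    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
        super().test_usage_metadata_streaming(model)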