diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index d5599c54ab7..50baef64d14 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -1,6 +1,5 @@ """Test ChatOpenAI chat model.""" - -from typing import Any, Optional, cast +from typing import Any, List, Optional, cast import pytest from langchain_core.callbacks import CallbackManager @@ -10,6 +9,7 @@ from langchain_core.messages import ( BaseMessageChunk, HumanMessage, SystemMessage, + ToolMessage, ) from langchain_core.outputs import ( ChatGeneration, @@ -470,6 +470,25 @@ async def test_async_response_metadata_streaming() -> None: assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"] +class GenerateUsername(BaseModel): + "Get a username based on someone's name and hair color." + + name: str + hair_color: str + + +def test_tool_use() -> None: + llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) + llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True) + msgs: List = [HumanMessage("Sally has green hair, what would her username be?")] + ai_msg = llm_with_tool.invoke(msgs) + tool_msg = ToolMessage( + "sally_green_hair", tool_call_id=ai_msg.additional_kwargs["tool_calls"][0]["id"] + ) + msgs.extend([ai_msg, tool_msg]) + llm_with_tool.invoke(msgs) + + def test_openai_structured_output() -> None: class MyModel(BaseModel): """A Person"""