community: fix CI (#21766)

This commit is contained in:
ccurme
2024-05-16 11:41:03 -04:00
committed by GitHub
parent dda5a9c97a
commit 19e6bf814b
2 changed files with 13 additions and 11 deletions

View File

@@ -151,9 +151,11 @@ class GenerateUsername(BaseModel):
def test_tool_use() -> None:
llm = ChatTongyi(model="qwen-turbo", temperature=0)
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [HumanMessage("Sally has green hair, what would her username be?")]
msgs: List = [
HumanMessage(content="Sally has green hair, what would her username be?")
]
ai_msg = llm_with_tool.invoke(msgs)
# assert ai_msg is None
# ai_msg.content = " "
@@ -165,8 +167,8 @@ def test_tool_use() -> None:
assert "args" in tool_call
tool_msg = ToolMessage(
"sally_green_hair",
tool_call_id=ai_msg.tool_calls[0]["id"],
content="sally_green_hair",
tool_call_id=ai_msg.tool_calls[0]["id"], # type: ignore
name=ai_msg.tool_calls[0]["name"],
)
msgs.extend([ai_msg, tool_msg])
@@ -184,7 +186,7 @@ def test_tool_use() -> None:
assert isinstance(gathered, AIMessageChunk)
streaming_tool_msg = ToolMessage(
"sally_green_hair",
content="sally_green_hair",
name=tool_call["name"],
tool_call_id=tool_call["id"] if tool_call["id"] else " ",
)
@@ -194,10 +196,10 @@ def test_tool_use() -> None:
def test_manual_tool_call_msg() -> None:
"""Test passing in manually construct tool call message."""
llm = ChatTongyi(model="qwen-turbo", temperature=0)
llm = ChatTongyi(model="qwen-turbo", temperature=0) # type: ignore
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [
HumanMessage("Sally has green hair, what would her username be?"),
HumanMessage(content="Sally has green hair, what would her username be?"),
AIMessage(
content=" ",
tool_calls=[
@@ -208,7 +210,7 @@ def test_manual_tool_call_msg() -> None:
)
],
),
ToolMessage("sally_green_hair", tool_call_id="foo"),
ToolMessage(content="sally_green_hair", tool_call_id="foo"),
]
output: AIMessage = cast(AIMessage, llm_with_tool.invoke(msgs))
assert output.content