Compare commits

...

1 Commits

Author SHA1 Message Date
Bagatur
b343799825 standard-tests[patch]: async variations of all tests 2024-08-16 11:19:32 -07:00

View File

@@ -133,7 +133,7 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(result.content, str)
assert len(result.content) > 0
def test_usage_metadata(self, model: BaseChatModel) -> None:
async def test_usage_metadata(self, model: BaseChatModel) -> None:
if not self.returns_usage_metadata:
pytest.skip("Not implemented.")
result = model.invoke("Hello")
@@ -144,11 +144,19 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(result.usage_metadata["output_tokens"], int)
assert isinstance(result.usage_metadata["total_tokens"], int)
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
result = await model.ainvoke("Hello")
assert result is not None
assert isinstance(result, AIMessage)
assert result.usage_metadata is not None
assert isinstance(result.usage_metadata["input_tokens"], int)
assert isinstance(result.usage_metadata["output_tokens"], int)
assert isinstance(result.usage_metadata["total_tokens"], int)
async def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
if not self.returns_usage_metadata:
pytest.skip("Not implemented.")
full: Optional[BaseMessageChunk] = None
for chunk in model.stream("Hello"):
async for chunk in model.astream("Hello"):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
@@ -157,7 +165,7 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(full.usage_metadata["output_tokens"], int)
assert isinstance(full.usage_metadata["total_tokens"], int)
def test_stop_sequence(self, model: BaseChatModel) -> None:
async def test_stop_sequence(self, model: BaseChatModel) -> None:
result = model.invoke("hi", stop=["you"])
assert isinstance(result, AIMessage)
@@ -166,8 +174,10 @@ class ChatModelIntegrationTests(ChatModelTests):
)
result = custom_model.invoke("hi")
assert isinstance(result, AIMessage)
result = await custom_model.ainvoke("hi")
assert isinstance(result, AIMessage)
def test_tool_calling(self, model: BaseChatModel) -> None:
async def test_tool_calling(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
model_with_tools = model.bind_tools([magic_function])
@@ -177,6 +187,9 @@ class ChatModelIntegrationTests(ChatModelTests):
result = model_with_tools.invoke(query)
_validate_tool_call_message(result)
result = await model_with_tools.ainvoke(query)
_validate_tool_call_message(result)
# Test stream
full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
@@ -184,7 +197,13 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
full = None
async for chunk in model_with_tools.astream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)
async def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
@@ -193,13 +212,22 @@ class ChatModelIntegrationTests(ChatModelTests):
result = model_with_tools.invoke(query)
_validate_tool_call_message_no_args(result)
result = await model_with_tools.ainvoke(query)
_validate_tool_call_message_no_args(result)
full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message_no_args(full)
def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
full = None
async for chunk in model_with_tools.astream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message_no_args(full)
async def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
@@ -221,7 +249,14 @@ class ChatModelIntegrationTests(ChatModelTests):
assert tool_call["args"].get("answer_style")
assert tool_call["type"] == "tool_call"
def test_structured_output(self, model: BaseChatModel) -> None:
result = await model_with_tools.ainvoke(query)
assert isinstance(result, AIMessage)
assert result.tool_calls
tool_call = result.tool_calls[0]
assert tool_call["args"].get("answer_style")
assert tool_call["type"] == "tool_call"
async def test_structured_output(self, model: BaseChatModel) -> None:
"""Test to verify structured output with a Pydantic model."""
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
@@ -243,22 +278,32 @@ class ChatModelIntegrationTests(ChatModelTests):
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke)
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, Joke)
for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)
async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)
# Schema
chat = model.with_structured_output(Joke.model_json_schema())
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
for chunk in chat.stream("Tell me a joke about cats."):
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
async def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
"""Test to verify compatibility with pydantic.v1.BaseModel.
pydantic.v1.BaseModel is available in the pydantic 2 package.
@@ -286,12 +331,18 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
for chunk in chat.stream("Tell me a joke about cats."):
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
def test_tool_message_histories_string_content(self, model: BaseChatModel) -> None:
async def test_tool_message_histories_string_content(
self, model: BaseChatModel
) -> None:
"""
Test that message histories are compatible with string tool contents
(e.g. OpenAI).
@@ -325,7 +376,10 @@ class ChatModelIntegrationTests(ChatModelTests):
result_string_content = model_with_tools.invoke(messages_string_content)
assert isinstance(result_string_content, AIMessage)
def test_tool_message_histories_list_content(
result_string_content = await model_with_tools.ainvoke(messages_string_content)
assert isinstance(result_string_content, AIMessage)
async def test_tool_message_histories_list_content(
self,
model: BaseChatModel,
) -> None:
@@ -370,7 +424,10 @@ class ChatModelIntegrationTests(ChatModelTests):
result_list_content = model_with_tools.invoke(messages_list_content)
assert isinstance(result_list_content, AIMessage)
def test_structured_few_shot_examples(self, model: BaseChatModel) -> None:
result_list_content = await model_with_tools.ainvoke(messages_list_content)
assert isinstance(result_list_content, AIMessage)
async def test_structured_few_shot_examples(self, model: BaseChatModel) -> None:
"""
Test that model can process few-shot examples with tool calls.
"""
@@ -404,8 +461,10 @@ class ChatModelIntegrationTests(ChatModelTests):
]
result_string_content = model_with_tools.invoke(messages_string_content)
assert isinstance(result_string_content, AIMessage)
result_string_content = await model_with_tools.ainvoke(messages_string_content)
assert isinstance(result_string_content, AIMessage)
def test_image_inputs(self, model: BaseChatModel) -> None:
async def test_image_inputs(self, model: BaseChatModel) -> None:
if not self.supports_image_inputs:
return
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
@@ -420,8 +479,9 @@ class ChatModelIntegrationTests(ChatModelTests):
],
)
model.invoke([message])
await model.ainvoke([message])
def test_anthropic_inputs(self, model: BaseChatModel) -> None:
async def test_anthropic_inputs(self, model: BaseChatModel) -> None:
if not self.supports_anthropic_inputs:
return
@@ -481,8 +541,9 @@ class ChatModelIntegrationTests(ChatModelTests):
),
]
model.bind_tools([color_picker]).invoke(messages)
await model.bind_tools([color_picker]).ainvoke(messages)
def test_tool_message_error_status(self, model: BaseChatModel) -> None:
async def test_tool_message_error_status(self, model: BaseChatModel) -> None:
"""Test that ToolMessage with status='error' can be handled."""
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
@@ -509,3 +570,5 @@ class ChatModelIntegrationTests(ChatModelTests):
]
result = model_with_tools.invoke(messages)
assert isinstance(result, AIMessage)
result = await model_with_tools.ainvoke(messages)
assert isinstance(result, AIMessage)