Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-03 15:55:44 +00:00)

Compare commits: mdrxy/vert ... eugene/add (4 commits)
| Author | SHA1 | Date |
|---|---|---|
| | ef374c81b3 | |
| | f7ea8bf6c5 | |
| | 0e2be55c28 | |
| | c7b68335ce | |
@@ -112,7 +112,7 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
             The output of the rate limiter.
         """

-        def _invoke(input: Input) -> Output:
+        def _invoke(input: Input, **kwargs: Any) -> Output:
             """Invoke the rate limiter. Internal function."""
             self.acquire(blocking=True)
             return cast(Output, input)
@@ -133,7 +133,7 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
             **kwargs: Additional keyword arguments.
         """

-        async def _ainvoke(input: Input) -> Output:
+        async def _ainvoke(input: Input, **kwargs: Any) -> Output:
             """Invoke the rate limiter. Internal function."""
             await self.aacquire(blocking=True)
             return cast(Output, input)
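For context on the hunks above: `BaseRateLimiter` is a `Runnable` whose internal `_invoke`/`_ainvoke` simply block on `acquire`/`aacquire` and then return the input unchanged, and the new `**kwargs` parameters are accepted and ignored so the limiter can absorb any extra call arguments it receives. A minimal standalone sketch, reusing the constructor arguments that appear in the test fixture below (the composed-model lines are commented out because they assume a provider-specific chat model not defined here):

```python
from langchain_core.runnables import InMemoryRateLimiter

# Roughly one request per second on average, with bursts of up to 10.
rate_limiter = InMemoryRateLimiter(requests_per_second=1, max_bucket_size=10)

# Block until a token is available, then continue.
rate_limiter.acquire(blocking=True)

# Or compose it in front of another runnable: the limiter waits, then
# forwards its input through unchanged.
# throttled_model = rate_limiter | chat_model
# throttled_model.invoke("Hello")
```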
@@ -17,6 +17,7 @@ from langchain_core.messages import (
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.runnables import InMemoryRateLimiter
 from langchain_core.tools import tool

 from langchain_standard_tests.unit_tests.chat_models import (
@@ -59,38 +60,58 @@ def _validate_tool_call_message_no_args(message: BaseMessage) -> None:


 class ChatModelIntegrationTests(ChatModelTests):
-    def test_invoke(self, model: BaseChatModel) -> None:
-        result = model.invoke("Hello")
+    @pytest.fixture(scope="class")
+    def rate_limiter(self) -> Optional[InMemoryRateLimiter]:
+        """Override to provide a different rate limiter to your model."""
+        return InMemoryRateLimiter(requests_per_second=1, max_bucket_size=10)
+
+    def test_invoke(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
+        model_ = rate_limiter | model
+        result = model_.invoke("Hello")
         assert result is not None
         assert isinstance(result, AIMessage)
         assert isinstance(result.content, str)
         assert len(result.content) > 0

-    async def test_ainvoke(self, model: BaseChatModel) -> None:
-        result = await model.ainvoke("Hello")
+    async def test_ainvoke(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
+        model_ = rate_limiter | model
+        result = await model_.ainvoke("Hello")
         assert result is not None
         assert isinstance(result, AIMessage)
         assert isinstance(result.content, str)
         assert len(result.content) > 0

-    def test_stream(self, model: BaseChatModel) -> None:
+    def test_stream(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         num_tokens = 0
-        for token in model.stream("Hello"):
+        model_ = rate_limiter | model
+        for token in model_.stream("Hello"):
             assert token is not None
             assert isinstance(token, AIMessageChunk)
             num_tokens += len(token.content)
         assert num_tokens > 0

-    async def test_astream(self, model: BaseChatModel) -> None:
+    async def test_astream(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         num_tokens = 0
-        async for token in model.astream("Hello"):
+        model_ = rate_limiter | model
+        async for token in model_.astream("Hello"):
             assert token is not None
             assert isinstance(token, AIMessageChunk)
             num_tokens += len(token.content)
         assert num_tokens > 0

-    def test_batch(self, model: BaseChatModel) -> None:
-        batch_results = model.batch(["Hello", "Hey"])
+    def test_batch(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
+        model_ = rate_limiter | model
+        batch_results = model_.batch(["Hello", "Hey"])
         assert batch_results is not None
         assert isinstance(batch_results, list)
         assert len(batch_results) == 2
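The class-scoped `rate_limiter` fixture above is meant to be overridden per provider. A hypothetical subclass that slows the whole suite down might look like the sketch below; the provider (`langchain_openai.ChatOpenAI`), the model name, the `chat_model_params` property, and the import path for `ChatModelIntegrationTests` are illustrative assumptions rather than part of this diff:

```python
from typing import Optional, Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import InMemoryRateLimiter
from langchain_openai import ChatOpenAI  # placeholder provider for this sketch

from langchain_standard_tests.integration_tests.chat_models import (
    ChatModelIntegrationTests,
)


class TestChatOpenAIStandard(ChatModelIntegrationTests):
    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatOpenAI

    @property
    def chat_model_params(self) -> dict:
        return {"model": "gpt-4o-mini"}  # hypothetical model choice

    @pytest.fixture(scope="class")
    def rate_limiter(self) -> Optional[InMemoryRateLimiter]:
        """Throttle the suite to roughly one request every five seconds."""
        return InMemoryRateLimiter(requests_per_second=0.2, max_bucket_size=1)
```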
@@ -100,8 +121,11 @@ class ChatModelIntegrationTests(ChatModelTests):
         assert isinstance(result.content, str)
         assert len(result.content) > 0

-    async def test_abatch(self, model: BaseChatModel) -> None:
-        batch_results = await model.abatch(["Hello", "Hey"])
+    async def test_abatch(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
+        model_ = rate_limiter | model
+        batch_results = await model_.abatch(["Hello", "Hey"])
         assert batch_results is not None
         assert isinstance(batch_results, list)
         assert len(batch_results) == 2
@@ -111,22 +135,28 @@ class ChatModelIntegrationTests(ChatModelTests):
         assert isinstance(result.content, str)
         assert len(result.content) > 0

-    def test_conversation(self, model: BaseChatModel) -> None:
+    def test_conversation(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         messages = [
             HumanMessage("hello"),
             AIMessage("hello"),
             HumanMessage("how are you"),
         ]
-        result = model.invoke(messages)
+        model_ = rate_limiter | model
+        result = model_.invoke(messages)
         assert result is not None
         assert isinstance(result, AIMessage)
         assert isinstance(result.content, str)
         assert len(result.content) > 0

-    def test_usage_metadata(self, model: BaseChatModel) -> None:
+    def test_usage_metadata(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         if not self.returns_usage_metadata:
             pytest.skip("Not implemented.")
-        result = model.invoke("Hello")
+        model_ = rate_limiter | model
+        result = model_.invoke("Hello")
         assert result is not None
         assert isinstance(result, AIMessage)
         assert result.usage_metadata is not None
@@ -134,11 +164,14 @@ class ChatModelIntegrationTests(ChatModelTests):
         assert isinstance(result.usage_metadata["output_tokens"], int)
         assert isinstance(result.usage_metadata["total_tokens"], int)

-    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
+    def test_usage_metadata_streaming(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         if not self.returns_usage_metadata:
             pytest.skip("Not implemented.")
         full: Optional[BaseMessageChunk] = None
-        for chunk in model.stream("Hello"):
+        model_ = rate_limiter | model
+        for chunk in model_.stream("Hello"):
             assert isinstance(chunk, AIMessageChunk)
             full = chunk if full is None else full + chunk
         assert isinstance(full, AIMessageChunk)
@@ -147,8 +180,11 @@ class ChatModelIntegrationTests(ChatModelTests):
         assert isinstance(full.usage_metadata["output_tokens"], int)
         assert isinstance(full.usage_metadata["total_tokens"], int)

-    def test_stop_sequence(self, model: BaseChatModel) -> None:
-        result = model.invoke("hi", stop=["you"])
+    def test_stop_sequence(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
+        model_ = rate_limiter | model.bind(stop=["you"])
+        result = model_.invoke("hi")
         assert isinstance(result, AIMessage)

         custom_model = self.chat_model_class(
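One detail worth noting in `test_stop_sequence`: the stop sequence is now bound onto the model with `.bind(stop=["you"])` rather than passed as an `invoke` kwarg, which keeps it attached to the model even though the call now goes through the composed `rate_limiter | model` sequence. A hypothetical equivalent outside the test suite, with the provider and model name as placeholders:

```python
from langchain_core.runnables import InMemoryRateLimiter
from langchain_openai import ChatOpenAI  # placeholder provider for this sketch

rate_limiter = InMemoryRateLimiter(requests_per_second=1, max_bucket_size=10)
model = ChatOpenAI(model="gpt-4o-mini")  # hypothetical model choice

# Binding attaches the stop list to the model itself, so it stays in effect
# when the model is invoked through the composed sequence.
throttled = rate_limiter | model.bind(stop=["you"])
result = throttled.invoke("hi")
```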
@@ -157,39 +193,45 @@ class ChatModelIntegrationTests(ChatModelTests):
         result = custom_model.invoke("hi")
         assert isinstance(result, AIMessage)

-    def test_tool_calling(self, model: BaseChatModel) -> None:
+    def test_tool_calling(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         if not self.has_tool_calling:
             pytest.skip("Test requires tool calling.")
-        model_with_tools = model.bind_tools([magic_function])
+        model_ = rate_limiter | model.bind_tools([magic_function])

         # Test invoke
         query = "What is the value of magic_function(3)? Use the tool."
-        result = model_with_tools.invoke(query)
+        result = model_.invoke(query)
         _validate_tool_call_message(result)

         # Test stream
         full: Optional[BaseMessageChunk] = None
-        for chunk in model_with_tools.stream(query):
+        for chunk in model_.stream(query):
             full = chunk if full is None else full + chunk  # type: ignore
         assert isinstance(full, AIMessage)
         _validate_tool_call_message(full)

-    def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
+    def test_tool_calling_with_no_arguments(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         if not self.has_tool_calling:
             pytest.skip("Test requires tool calling.")

-        model_with_tools = model.bind_tools([magic_function_no_args])
+        model_ = rate_limiter | model.bind_tools([magic_function_no_args])
         query = "What is the value of magic_function()? Use the tool."
-        result = model_with_tools.invoke(query)
+        result = model_.invoke(query)
         _validate_tool_call_message_no_args(result)

         full: Optional[BaseMessageChunk] = None
-        for chunk in model_with_tools.stream(query):
+        for chunk in model_.stream(query):
             full = chunk if full is None else full + chunk  # type: ignore
         assert isinstance(full, AIMessage)
         _validate_tool_call_message_no_args(full)

-    def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
+    def test_bind_runnables_as_tools(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         if not self.has_tool_calling:
             pytest.skip("Test requires tool calling.")

@@ -203,15 +245,18 @@ class ChatModelIntegrationTests(ChatModelTests):
             description="Generate a greeting in a particular style of speaking.",
         )
         model_with_tools = model.bind_tools([tool_])
+        model_ = rate_limiter | model_with_tools
         query = "Using the tool, generate a Pirate greeting."
-        result = model_with_tools.invoke(query)
+        result = model_.invoke(query)
         assert isinstance(result, AIMessage)
         assert result.tool_calls
         tool_call = result.tool_calls[0]
         assert tool_call["args"].get("answer_style")
         assert tool_call["type"] == "tool_call"

-    def test_structured_output(self, model: BaseChatModel) -> None:
+    def test_structured_output(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         """Test to verify structured output with a Pydantic model."""
         if not self.has_tool_calling:
             pytest.skip("Test requires tool calling.")
@@ -230,25 +275,29 @@ class ChatModelIntegrationTests(ChatModelTests):
         # or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
         # We'll need to do a pass updating the type signatures.
         chat = model.with_structured_output(Joke)  # type: ignore[arg-type]
-        result = chat.invoke("Tell me a joke about cats.")
+        model_ = rate_limiter | chat
+        result = model_.invoke("Tell me a joke about cats.")
         assert isinstance(result, Joke)

-        for chunk in chat.stream("Tell me a joke about cats."):
+        for chunk in model_.stream("Tell me a joke about cats."):
             assert isinstance(chunk, Joke)

         # Schema
         chat = model.with_structured_output(Joke.schema())
-        result = chat.invoke("Tell me a joke about cats.")
+        model_ = rate_limiter | chat
+        result = model_.invoke("Tell me a joke about cats.")
         assert isinstance(result, dict)
         assert set(result.keys()) == {"setup", "punchline"}

-        for chunk in chat.stream("Tell me a joke about cats."):
-            assert isinstance(chunk, dict)
+        for chunk in model_.stream("Tell me a joke about cats."):
+            assert isinstance(chunk, dict)  # for mypy
             assert set(chunk.keys()) == {"setup", "punchline"}

     @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
-    def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
+    def test_structured_output_pydantic_2_v1(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         """Test to verify compatibility with pydantic.v1.BaseModel.

         pydantic.v1.BaseModel is available in the pydantic 2 package.
@@ -263,20 +312,20 @@ class ChatModelIntegrationTests(ChatModelTests):
             punchline: str = Field(description="answer to resolve the joke")

         # Pydantic class
-        chat = model.with_structured_output(Joke)
-        result = chat.invoke("Tell me a joke about cats.")
+        model_ = rate_limiter | model.with_structured_output(Joke)
+        result = model_.invoke("Tell me a joke about cats.")
         assert isinstance(result, Joke)

-        for chunk in chat.stream("Tell me a joke about cats."):
+        for chunk in model_.stream("Tell me a joke about cats."):
             assert isinstance(chunk, Joke)

         # Schema
-        chat = model.with_structured_output(Joke.schema())
-        result = chat.invoke("Tell me a joke about cats.")
+        model_ = rate_limiter | model.with_structured_output(Joke.schema())
+        result = model_.invoke("Tell me a joke about cats.")
         assert isinstance(result, dict)
         assert set(result.keys()) == {"setup", "punchline"}

-        for chunk in chat.stream("Tell me a joke about cats."):
-            assert isinstance(chunk, dict)
+        for chunk in model_.stream("Tell me a joke about cats."):
+            assert isinstance(chunk, dict)  # for mypy
             assert set(chunk.keys()) == {"setup", "punchline"}
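Outside the suite, the same composition pattern applies to structured output: the limiter receives the raw prompt, blocks until a token is available, and passes the prompt through unchanged to the structured-output runnable. A hedged sketch with a placeholder provider (the `Joke` model mirrors the one used in these tests):

```python
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import InMemoryRateLimiter
from langchain_openai import ChatOpenAI  # placeholder provider for this sketch


class Joke(BaseModel):
    """Joke to tell the user."""

    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")


rate_limiter = InMemoryRateLimiter(requests_per_second=1, max_bucket_size=10)
model = ChatOpenAI(model="gpt-4o-mini")  # hypothetical model choice

# Limiter first, structured output second: each call is paced, and the
# result comes back as a Joke instance rather than an AIMessage.
chain = rate_limiter | model.with_structured_output(Joke)
joke = chain.invoke("Tell me a joke about cats.")
print(joke.setup, joke.punchline)
```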
@@ -284,6 +333,7 @@ class ChatModelIntegrationTests(ChatModelTests):
     def test_tool_message_histories_string_content(
         self,
         model: BaseChatModel,
+        rate_limiter: InMemoryRateLimiter,
     ) -> None:
         """
         Test that message histories are compatible with string tool contents
@@ -291,7 +341,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         """
         if not self.has_tool_calling:
             pytest.skip("Test requires tool calling.")
-        model_with_tools = model.bind_tools([my_adder_tool])
+        model_with_tools = rate_limiter | model.bind_tools([my_adder_tool])
         function_name = "my_adder_tool"
         function_args = {"a": "1", "b": "2"}

@@ -321,6 +371,7 @@ class ChatModelIntegrationTests(ChatModelTests):
     def test_tool_message_histories_list_content(
         self,
         model: BaseChatModel,
+        rate_limiter: InMemoryRateLimiter,
     ) -> None:
         """
         Test that message histories are compatible with list tool contents
@@ -328,7 +379,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         """
         if not self.has_tool_calling:
             pytest.skip("Test requires tool calling.")
-        model_with_tools = model.bind_tools([my_adder_tool])
+        model_with_tools = rate_limiter | model.bind_tools([my_adder_tool])
         function_name = "my_adder_tool"
         function_args = {"a": 1, "b": 2}

@@ -363,13 +414,17 @@ class ChatModelIntegrationTests(ChatModelTests):
         result_list_content = model_with_tools.invoke(messages_list_content)
         assert isinstance(result_list_content, AIMessage)

-    def test_structured_few_shot_examples(self, model: BaseChatModel) -> None:
+    def test_structured_few_shot_examples(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         """
         Test that model can process few-shot examples with tool calls.
         """
         if not self.has_tool_calling:
             pytest.skip("Test requires tool calling.")
-        model_with_tools = model.bind_tools([my_adder_tool], tool_choice="any")
+        model_with_tools = rate_limiter | model.bind_tools(
+            [my_adder_tool], tool_choice="any"
+        )
         function_name = "my_adder_tool"
         function_args = {"a": 1, "b": 2}
         function_result = json.dumps({"result": 3})
@@ -398,7 +453,9 @@ class ChatModelIntegrationTests(ChatModelTests):
         result_string_content = model_with_tools.invoke(messages_string_content)
         assert isinstance(result_string_content, AIMessage)

-    def test_image_inputs(self, model: BaseChatModel) -> None:
+    def test_image_inputs(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         if not self.supports_image_inputs:
             return
         image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
@@ -414,7 +471,9 @@ class ChatModelIntegrationTests(ChatModelTests):
         )
         model.invoke([message])

-    def test_anthropic_inputs(self, model: BaseChatModel) -> None:
+    def test_anthropic_inputs(
+        self, model: BaseChatModel, rate_limiter: InMemoryRateLimiter
+    ) -> None:
         if not self.supports_anthropic_inputs:
             return

@@ -472,4 +531,4 @@ class ChatModelIntegrationTests(ChatModelTests):
                 ]
             ),
         ]
-        model.bind_tools([color_picker]).invoke(messages)
+        (rate_limiter | model.bind_tools([color_picker])).invoke(messages)