Add Batch Size kwarg to the llm start callback (#13483)
This makes it easier to use the token counts returned by the API endpoint directly when the batch size is 1.
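The tests added below exercise the new kwarg through the run collector, asserting that batch_size shows up in each traced run's extra. As a rough illustration of how downstream code might consume it, here is a minimal sketch of a custom callback handler; BatchSizeLogger is a hypothetical name, and the sketch assumes batch_size is forwarded to on_llm_start via **kwargs, as the PR title describes.

# Hypothetical example, not part of this commit.
from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler


class BatchSizeLogger(BaseCallbackHandler):
    """Record the batch size reported when each LLM call starts."""

    def __init__(self) -> None:
        self.batch_sizes: List[int] = []

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # With batch_size == 1, the token counts in the provider's response
        # belong to a single request and need no further splitting.
        self.batch_sizes.append(kwargs.get("batch_size", 1))


# Attached like any other handler, e.g.:
#   handler = BatchSizeLogger()
#   llm.batch(["foo", "bar"], {"callbacks": [handler]})
#   print(handler.batch_sizes)  # the reported batch_size should be 2 here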
@@ -0,0 +1,71 @@
"""Test base chat model."""
import pytest

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.chat_model import FakeListChatModel


@pytest.fixture
def messages() -> list:
    return [
        SystemMessage(content="You are a test user."),
        HumanMessage(content="Hello, I am a test user."),
    ]


@pytest.fixture
def messages_2() -> list:
    return [
        SystemMessage(content="You are a test user."),
        HumanMessage(content="Hello, I not a test user."),
    ]


def test_batch_size(messages: list, messages_2: list) -> None:
    # The base endpoint doesn't support native batching,
    # so we expect batch_size to always be 1
    llm = FakeListChatModel(responses=[str(i) for i in range(100)])
    with collect_runs() as cb:
        llm.batch([messages, messages_2], {"callbacks": [cb]})
        assert len(cb.traced_runs) == 2
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
    with collect_runs() as cb:
        llm.batch([messages], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 1

    with collect_runs() as cb:
        llm.invoke(messages)
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    with collect_runs() as cb:
        list(llm.stream(messages))
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1


async def test_async_batch_size(messages: list, messages_2: list) -> None:
    llm = FakeListChatModel(responses=[str(i) for i in range(100)])
    # The base endpoint doesn't support native batching,
    # so we expect batch_size to always be 1
    with collect_runs() as cb:
        await llm.abatch([messages, messages_2], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 2
    with collect_runs() as cb:
        await llm.abatch([messages], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 1

    with collect_runs() as cb:
        await llm.ainvoke(messages)
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    with collect_runs() as cb:
        async for _ in llm.astream(messages):
            pass
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
@@ -1,3 +1,4 @@
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.llm import FakeListLLM
@@ -17,3 +18,60 @@ async def test_abatch() -> None:
    output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
    assert output == ["foo"] * 3


def test_batch_size() -> None:
    llm = FakeListLLM(responses=["foo"] * 3)
    with collect_runs() as cb:
        llm.batch(["foo", "bar", "foo"], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 3
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        llm.batch(["foo"], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 1

    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        llm.invoke("foo")
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        list(llm.stream("foo"))
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    llm = FakeListLLM(responses=["foo"] * 1)
    with collect_runs() as cb:
        llm.predict("foo")
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1


async def test_async_batch_size() -> None:
    llm = FakeListLLM(responses=["foo"] * 3)
    with collect_runs() as cb:
        await llm.abatch(["foo", "bar", "foo"], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 3
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        await llm.abatch(["foo"], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 1

    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        await llm.ainvoke("foo")
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        async for _ in llm.astream("foo"):
            pass
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
File diff suppressed because one or more lines are too long