community: support usage_metadata for litellm streaming calls (#30683)
Support "usage_metadata" for LiteLLM streaming calls. This is a follow-up to https://github.com/langchain-ai/langchain/pull/30625, which tackled non-streaming calls. If no one reviews your PR within a few days, please @-mention one of baskaryan, eyurtsev, ccurme, vbarda, hwchase17.
parent 5ffcd01c41
commit 34ddfba76b
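For orientation, a minimal sketch of how the change surfaces to a caller. The model name is hypothetical, and the stream_options kwarg mirrors the integration-test change below (some providers only attach usage to streamed responses when asked to include it):

from langchain_community.chat_models import ChatLiteLLM

llm = ChatLiteLLM(
    model="gpt-4o-mini",  # hypothetical model; any LiteLLM-supported provider works
    # Opt in to streaming usage for providers that require it.
    model_kwargs={"stream_options": {"include_usage": True}},
)

total = None
for chunk in llm.stream("Hello"):
    # AIMessageChunk addition merges usage_metadata across chunks.
    total = chunk if total is None else total + chunk

print(total.usage_metadata)  # input_tokens / output_tokens / total_tokens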
libs/community/langchain_community/chat_models/litellm.py

@@ -452,6 +452,7 @@ class ChatLiteLLM(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class = AIMessageChunk
+        added_model_name = False
         for chunk in self.completion_with_retry(
             messages=message_dicts, run_manager=run_manager, **params
         ):
@@ -460,7 +461,15 @@ class ChatLiteLLM(BaseChatModel):
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
+            usage = chunk.get("usage", {})
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            if isinstance(chunk, AIMessageChunk):
+                if not added_model_name:
+                    chunk.response_metadata = {
+                        "model_name": self.model_name or self.model
+                    }
+                    added_model_name = True
+                chunk.usage_metadata = _create_usage_metadata(usage)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
             if run_manager:
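Both streaming paths lean on the _create_usage_metadata helper introduced by the non-streaming PR (#30625). A rough sketch of what such a helper plausibly does with LiteLLM's OpenAI-style usage dict, assuming the usual prompt_tokens/completion_tokens keys; the actual helper lives elsewhere in this module:

from langchain_core.messages.ai import UsageMetadata

def _create_usage_metadata(token_usage: dict) -> UsageMetadata:
    # LiteLLM normalizes providers to OpenAI-style usage keys.
    input_tokens = token_usage.get("prompt_tokens", 0)
    output_tokens = token_usage.get("completion_tokens", 0)
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=input_tokens + output_tokens,
    )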
@@ -478,6 +487,7 @@ class ChatLiteLLM(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class = AIMessageChunk
+        added_model_name = False
         async for chunk in await acompletion_with_retry(
             self, messages=message_dicts, run_manager=run_manager, **params
         ):
@@ -486,7 +496,15 @@ class ChatLiteLLM(BaseChatModel):
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
+            usage = chunk.get("usage", {})
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            if isinstance(chunk, AIMessageChunk):
+                if not added_model_name:
+                    chunk.response_metadata = {
+                        "model_name": self.model_name or self.model
+                    }
+                    added_model_name = True
+                chunk.usage_metadata = _create_usage_metadata(usage)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
             if run_manager:
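The async hunks mirror the sync ones, so astream picks up the same metadata. The model name is stamped on the first AIMessageChunk only (the added_model_name flag), which keeps it from being duplicated when chunks are merged. A small sketch, reusing the hypothetical model from above:

import asyncio

from langchain_community.chat_models import ChatLiteLLM

async def main() -> None:
    llm = ChatLiteLLM(
        model="gpt-4o-mini",  # hypothetical model name
        model_kwargs={"stream_options": {"include_usage": True}},
    )
    total = None
    async for chunk in llm.astream("Hello"):
        total = chunk if total is None else total + chunk
    # usage_metadata is summed across chunks; model_name appears once.
    print(total.usage_metadata)
    print(total.response_metadata)

asyncio.run(main())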
libs/community/tests/integration_tests/chat_models/test_litellm.py

@@ -2,7 +2,6 @@
 
 from typing import Type
 
-import pytest
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 
@@ -16,8 +15,8 @@ class TestLiteLLMStandard(ChatModelIntegrationTests):
 
     @property
    def chat_model_params(self) -> dict:
-        return {"model": "ollama/mistral"}
-
-    @pytest.mark.xfail(reason="Not yet implemented.")
-    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata_streaming(model)
+        return {
+            "model": "ollama/mistral",
+            # Needed to get the usage object when streaming. See https://docs.litellm.ai/docs/completion/usage#streaming-usage
+            "model_kwargs": {"stream_options": {"include_usage": True}},
+        }
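The xfail on test_usage_metadata_streaming is dropped, so the standard suite now exercises streaming usage directly; the stream_options entry follows the linked LiteLLM docs. Roughly, the underlying call the test depends on looks like this (a sketch only; it needs the litellm package and a configured provider, and usage typically arrives on the final chunk):

import litellm

for chunk in litellm.completion(
    model="ollama/mistral",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},
):
    usage = getattr(chunk, "usage", None)
    if usage:
        print(usage)  # prompt_tokens / completion_tokens / total_tokens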