community: Fix ChatLiteLLMRouter runtime issues (#28163)
**Description:** Fix ChatLiteLLMRouter constructor validation and the model_name parameter
**Issue:** #19356, #27455, #28077
**Twitter handle:** @bburgin_0
This commit is contained in:
parent 234d49653a
commit 27a9056725
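For context, here is a minimal usage sketch of the pattern this commit's notebook documents: constructing `ChatLiteLLMRouter` with an explicit `model_name` that names a model group in the router's `model_list`. The deployment name, API key, and endpoint below are placeholders, and the import paths follow the code in the diff.

```python
# Minimal sketch of the updated usage shown in the notebook diff below.
# The deployment name, API key, and endpoint are placeholders.
from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter
from litellm import Router

model_list = [
    {
        "model_name": "gpt-35-turbo",
        "litellm_params": {
            "model": "azure/gpt-35-turbo",
            "api_key": "<your-api-key>",
            "api_version": "2023-05-15",
            "api_base": "https://<your-endpoint>.openai.azure.com/",
        },
    },
]

litellm_router = Router(model_list=model_list)
# model_name must name a model group present in the router's model_list
chat = ChatLiteLLMRouter(router=litellm_router, model_name="gpt-35-turbo")
```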
@@ -63,9 +63,9 @@
 " },\n",
 " },\n",
 " {\n",
-" \"model_name\": \"gpt-4\",\n",
+" \"model_name\": \"gpt-35-turbo\",\n",
 " \"litellm_params\": {\n",
-" \"model\": \"azure/gpt-4-1106-preview\",\n",
+" \"model\": \"azure/gpt-35-turbo\",\n",
 " \"api_key\": \"<your-api-key>\",\n",
 " \"api_version\": \"2023-05-15\",\n",
 " \"api_base\": \"https://<your-endpoint>.openai.azure.com/\",\n",
@@ -73,7 +73,7 @@
 " },\n",
 "]\n",
 "litellm_router = Router(model_list=model_list)\n",
-"chat = ChatLiteLLMRouter(router=litellm_router)"
+"chat = ChatLiteLLMRouter(router=litellm_router, model_name=\"gpt-35-turbo\")"
 ]
 },
 {
@@ -177,6 +177,7 @@
 "source": [
 "chat = ChatLiteLLMRouter(\n",
 " router=litellm_router,\n",
+" model_name=\"gpt-35-turbo\",\n",
 " streaming=True,\n",
 " verbose=True,\n",
 " callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
@@ -209,7 +210,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.13"
+"version": "3.11.9"
 }
 },
 "nbformat": 4,
@@ -1,13 +1,6 @@
 """LiteLLM Router as LangChain Model."""

-from typing import (
-    Any,
-    AsyncIterator,
-    Iterator,
-    List,
-    Mapping,
-    Optional,
-)
+from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional

 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -17,15 +10,8 @@ from langchain_core.language_models.chat_models import (
     agenerate_from_stream,
     generate_from_stream,
 )
-from langchain_core.messages import (
-    AIMessageChunk,
-    BaseMessage,
-)
-from langchain_core.outputs import (
-    ChatGeneration,
-    ChatGenerationChunk,
-    ChatResult,
-)
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

 from langchain_community.chat_models.litellm import (
     ChatLiteLLM,
@@ -33,8 +19,8 @@ from langchain_community.chat_models.litellm import (
     _convert_dict_to_message,
 )

-token_usage_key_name = "token_usage"
-model_extra_key_name = "model_extra"
+token_usage_key_name = "token_usage"  # nosec # incorrectly flagged as password
+model_extra_key_name = "model_extra"  # nosec # incorrectly flagged as password


 def get_llm_output(usage: Any, **params: Any) -> dict:
@@ -56,21 +42,14 @@ class ChatLiteLLMRouter(ChatLiteLLM):

     def __init__(self, *, router: Any, **kwargs: Any) -> None:
         """Construct Chat LiteLLM Router."""
-        super().__init__(**kwargs)
+        super().__init__(router=router, **kwargs)  # type: ignore
         self.router = router

     @property
     def _llm_type(self) -> str:
         return "LiteLLMRouter"

-    def _set_model_for_completion(self) -> None:
-        # use first model name (aka: model group),
-        # since we can only pass one to the router completion functions
-        self.model = self.router.model_list[0]["model_name"]
-
     def _prepare_params_for_router(self, params: Any) -> None:
-        params["model"] = self.model
-
         # allow the router to set api_base based on its model choice
         api_base_key_name = "api_base"
         if api_base_key_name in params and params[api_base_key_name] is None:
@@ -79,6 +58,22 @@ class ChatLiteLLMRouter(ChatLiteLLM):
         # add metadata so router can fill it below
         params.setdefault("metadata", {})

+    def set_default_model(self, model_name: str) -> None:
+        """Set the default model to use for completion calls.
+
+        Sets `self.model` to `model_name` if it is in the litellm router's
+        (`self.router`) model list. This provides the default model to use
+        for completion calls if no `model` kwarg is provided.
+        """
+        model_list = self.router.model_list
+        if not model_list:
+            raise ValueError("model_list is None or empty.")
+        for entry in model_list:
+            if entry["model_name"] == model_name:
+                self.model = model_name
+                return
+        raise ValueError(f"Model {model_name} not found in model_list.")
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -96,7 +91,6 @@ class ChatLiteLLMRouter(ChatLiteLLM):

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         response = self.router.completion(
@@ -115,7 +109,6 @@ class ChatLiteLLMRouter(ChatLiteLLM):
         default_chunk_class = AIMessageChunk
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         for chunk in self.router.completion(messages=message_dicts, **params):
@@ -139,7 +132,6 @@ class ChatLiteLLMRouter(ChatLiteLLM):
         default_chunk_class = AIMessageChunk
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         async for chunk in await self.router.acompletion(
@@ -174,7 +166,6 @@ class ChatLiteLLMRouter(ChatLiteLLM):

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         response = await self.router.acompletion(
@@ -196,14 +187,14 @@ class ChatLiteLLMRouter(ChatLiteLLM):
             token_usage = output["token_usage"]
             if token_usage is not None:
                 # get dict from LiteLLM Usage class
-                for k, v in token_usage.dict().items():
-                    if k in overall_token_usage:
+                for k, v in token_usage.model_dump().items():
+                    if k in overall_token_usage and overall_token_usage[k] is not None:
                         overall_token_usage[k] += v
                     else:
                         overall_token_usage[k] = v
             if system_fingerprint is None:
                 system_fingerprint = output.get("system_fingerprint")
-        combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
+        combined = {"token_usage": overall_token_usage, "model_name": self.model}
         if system_fingerprint:
             combined["system_fingerprint"] = system_fingerprint
         return combined
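The new `set_default_model` method added above validates a model group name against the router's `model_list` before making it the default. A small sketch of how it behaves, continuing from the `litellm_router` and `chat` objects in the earlier snippet (which define only a "gpt-35-turbo" model group):

```python
# Continues the earlier sketch: the router only defines a "gpt-35-turbo" group.
chat.set_default_model("gpt-35-turbo")  # accepted: matches an entry in model_list

try:
    chat.set_default_model("gpt-4")  # no such model group in this router
except ValueError as err:
    print(err)  # -> Model gpt-4 not found in model_list.
```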
@@ -1,8 +1,20 @@
 """Test LiteLLM Router API wrapper."""

 import asyncio
+import queue
+import threading
 from copy import deepcopy
-from typing import Any, AsyncGenerator, Coroutine, Dict, List, Tuple, Union, cast
+from typing import (
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Dict,
+    Generator,
+    List,
+    Tuple,
+    Union,
+    cast,
+)

 import pytest
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
@@ -11,7 +23,8 @@ from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

-model_group = "gpt-4"
+model_group_gpt4 = "gpt-4"
+model_group_to_test = "gpt-35-turbo"
 fake_model_prefix = "azure/fake-deployment-name-"
 fake_models_names = [fake_model_prefix + suffix for suffix in ["1", "2"]]
 fake_api_key = "fakekeyvalue"
@@ -23,7 +36,7 @@ token_usage_key_name = "token_usage"

 model_list = [
     {
-        "model_name": model_group,
+        "model_name": model_group_gpt4,
         "litellm_params": {
             "model": fake_models_names[0],
             "api_key": fake_api_key,
@@ -32,7 +45,7 @@ model_list = [
         },
     },
     {
-        "model_name": model_group,
+        "model_name": model_group_to_test,
         "litellm_params": {
             "model": fake_models_names[1],
             "api_key": fake_api_key,
@@ -43,6 +56,39 @@ model_list = [
 ]


+# from https://stackoverflow.com/a/78573267
+def aiter_to_iter(it: AsyncIterator) -> Generator:
+    "Convert an async iterator into a regular (sync) iterator."
+    q_in: queue.SimpleQueue = queue.SimpleQueue()
+    q_out: queue.SimpleQueue = queue.SimpleQueue()
+
+    async def threadmain() -> None:
+        try:
+            # Wait until the sync generator requests an item before continuing
+            while q_in.get():
+                q_out.put((True, await it.__anext__()))
+        except StopAsyncIteration:
+            q_out.put((False, None))
+        except BaseException as ex:
+            q_out.put((False, ex))
+
+    thread = threading.Thread(target=asyncio.run, args=(threadmain(),), daemon=True)
+    thread.start()
+
+    try:
+        while True:
+            q_in.put(True)
+            cont, result = q_out.get()
+            if cont:
+                yield result
+            elif result is None:
+                break
+            else:
+                raise result
+    finally:
+        q_in.put(False)
+
+
 class FakeCompletion:
     def __init__(self) -> None:
         self.seen_inputs: List[Any] = []
@@ -55,13 +101,6 @@ class FakeCompletion:
         choices = cast(List[Dict[str, Any]], result["choices"])
         return result, choices

-    @staticmethod
-    def _get_next_result(
-        agen: AsyncGenerator[Dict[str, Any], None],
-    ) -> Dict[str, Any]:
-        coroutine = cast(Coroutine, agen.__anext__())
-        return asyncio.run(coroutine)
-
     async def _get_fake_results_agenerator(
         self, **kwargs: Any
     ) -> AsyncGenerator[Dict[str, Any], None]:
@@ -76,7 +115,7 @@ class FakeCompletion:
             ],
             "created": 0,
             "id": "",
-            "model": model_group,
+            "model": model_group_to_test,
             "object": "chat.completion",
         }
         if kwargs["stream"]:
@@ -115,17 +154,18 @@ class FakeCompletion:

     def completion(self, **kwargs: Any) -> Union[List, Dict[str, Any]]:
         agen = self._get_fake_results_agenerator(**kwargs)
+        synchronous_iter = aiter_to_iter(agen)
         if kwargs["stream"]:
             results: List[Dict[str, Any]] = []
             while True:
                 try:
-                    results.append(self._get_next_result(agen))
-                except StopAsyncIteration:
+                    results.append(synchronous_iter.__next__())
+                except StopIteration:
                     break
             return results
         else:
             # there is only one result for non-streaming
-            return self._get_next_result(agen)
+            return synchronous_iter.__next__()

     async def acompletion(
         self, **kwargs: Any
@@ -142,7 +182,7 @@ class FakeCompletion:
         for kwargs in self.seen_inputs:
             metadata = kwargs["metadata"]

-            assert metadata["model_group"] == model_group
+            assert metadata["model_group"] == model_group_to_test

             # LiteLLM router chooses one model name from the model_list
             assert kwargs["model"] in fake_models_names
@@ -172,17 +212,16 @@ def litellm_router() -> Any:
     """LiteLLM router for testing."""
     from litellm import Router

-    return Router(
-        model_list=model_list,
-    )
+    return Router(model_list=model_list)


 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 def test_litellm_router_call(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
     """Test valid call to LiteLLM Router."""
-    chat = ChatLiteLLMRouter(router=litellm_router)
+    chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test)
     message = HumanMessage(content="Hello")

     response = chat.invoke([message])
@@ -195,13 +234,12 @@ def test_litellm_router_call(


 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 def test_litellm_router_generate(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
     """Test generate method of LiteLLM Router."""
-    from litellm import Usage
-
-    chat = ChatLiteLLMRouter(router=litellm_router)
+    chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test)
     chat_messages: List[List[BaseMessage]] = [
         [HumanMessage(content="How many toes do dogs have?")]
     ]
@@ -219,18 +257,25 @@ def test_litellm_router_generate(
     assert generation.message.content == fake_answer
     assert chat_messages == messages_copy
     assert result.llm_output is not None
-    assert result.llm_output[token_usage_key_name] == Usage(
-        completion_tokens=1, prompt_tokens=2, total_tokens=3
-    )
+    assert result.llm_output[token_usage_key_name] == {
+        "completion_tokens": 1,
+        "completion_tokens_details": None,
+        "prompt_tokens": 2,
+        "prompt_tokens_details": None,
+        "total_tokens": 3,
+    }
     fake_completion.check_inputs(expected_num_calls=1)


 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 def test_litellm_router_streaming(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
     """Test streaming tokens from LiteLLM Router."""
-    chat = ChatLiteLLMRouter(router=litellm_router, streaming=True)
+    chat = ChatLiteLLMRouter(
+        router=litellm_router, model_name=model_group_to_test, streaming=True
+    )
     message = HumanMessage(content="Hello")

     response = chat.invoke([message])
@@ -243,6 +288,7 @@ def test_litellm_router_streaming(


 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 def test_litellm_router_streaming_callback(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
@@ -250,6 +296,7 @@ def test_litellm_router_streaming_callback(
     callback_handler = FakeCallbackHandler()
     chat = ChatLiteLLMRouter(
         router=litellm_router,
+        model_name=model_group_to_test,
         streaming=True,
         callbacks=[callback_handler],
         verbose=True,
@@ -267,13 +314,12 @@ def test_litellm_router_streaming_callback(


 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 async def test_async_litellm_router(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
     """Test async generation."""
-    from litellm import Usage
-
-    chat = ChatLiteLLMRouter(router=litellm_router)
+    chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test)
     message = HumanMessage(content="Hello")

     response = await chat.agenerate([[message], [message]])
@@ -288,13 +334,18 @@ async def test_async_litellm_router(
     assert generation.message.content == generation.text
     assert generation.message.content == fake_answer
     assert response.llm_output is not None
-    assert response.llm_output[token_usage_key_name] == Usage(
-        completion_tokens=2, prompt_tokens=4, total_tokens=6
-    )
+    assert response.llm_output[token_usage_key_name] == {
+        "completion_tokens": 2,
+        "completion_tokens_details": None,
+        "prompt_tokens": 4,
+        "prompt_tokens_details": None,
+        "total_tokens": 6,
+    }
     fake_completion.check_inputs(expected_num_calls=2)


 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 async def test_async_litellm_router_streaming(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
@@ -302,6 +353,7 @@ async def test_async_litellm_router_streaming(
     callback_handler = FakeCallbackHandler()
     chat = ChatLiteLLMRouter(
         router=litellm_router,
+        model_name=model_group_to_test,
         streaming=True,
         callbacks=[callback_handler],
         verbose=True,
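The `aiter_to_iter` helper introduced in the tests above drives an async iterator from a daemon thread so its items can be consumed synchronously, replacing the old `_get_next_result`, which called `asyncio.run` once per item. A standalone sketch, assuming the helper definition from the diff is in scope; `ticker` is an invented async generator used only for illustration:

```python
import asyncio
from typing import AsyncIterator


async def ticker(n: int) -> AsyncIterator[int]:
    # Invented async generator for demonstration purposes only.
    for i in range(n):
        await asyncio.sleep(0)
        yield i


# aiter_to_iter runs the async iterator in a background event loop and
# yields its items synchronously, one request at a time.
for value in aiter_to_iter(ticker(3)):
    print(value)  # prints 0, 1, 2
```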