community: Fix ChatLiteLLMRouter runtime issues (#28163)

**Description:** Fix `ChatLiteLLMRouter` constructor validation and the `model_name` parameter
**Issue:** #19356, #27455, #28077
**Twitter handle:** @bburgin_0
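
With this change, `ChatLiteLLMRouter` takes the model group to route to as `model_name` instead of silently using the first entry of the router's `model_list`. A minimal usage sketch based on the updated notebook below (the Azure endpoint, API key, and deployment names are placeholders):

```python
from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter
from langchain_core.messages import HumanMessage
from litellm import Router

# One "model group" the router can load-balance over; all values are placeholders.
model_list = [
    {
        "model_name": "gpt-35-turbo",
        "litellm_params": {
            "model": "azure/gpt-35-turbo",
            "api_key": "<your-api-key>",
            "api_version": "2023-05-15",
            "api_base": "https://<your-endpoint>.openai.azure.com/",
        },
    },
]

litellm_router = Router(model_list=model_list)
# model_name must match a model group registered on the router.
chat = ChatLiteLLMRouter(router=litellm_router, model_name="gpt-35-turbo")
chat.invoke([HumanMessage(content="Hello")])
```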
Brian Burgin 2024-12-16 17:17:39 -06:00 committed by GitHub
parent 234d49653a
commit 27a9056725
3 changed files with 115 additions and 71 deletions

View File

@@ -63,9 +63,9 @@
 " },\n",
 " },\n",
 " {\n",
-" \"model_name\": \"gpt-4\",\n",
+" \"model_name\": \"gpt-35-turbo\",\n",
 " \"litellm_params\": {\n",
-" \"model\": \"azure/gpt-4-1106-preview\",\n",
+" \"model\": \"azure/gpt-35-turbo\",\n",
 " \"api_key\": \"<your-api-key>\",\n",
 " \"api_version\": \"2023-05-15\",\n",
 " \"api_base\": \"https://<your-endpoint>.openai.azure.com/\",\n",
@@ -73,7 +73,7 @@
 " },\n",
 "]\n",
 "litellm_router = Router(model_list=model_list)\n",
-"chat = ChatLiteLLMRouter(router=litellm_router)"
+"chat = ChatLiteLLMRouter(router=litellm_router, model_name=\"gpt-35-turbo\")"
 ]
 },
 {
@@ -177,6 +177,7 @@
 "source": [
 "chat = ChatLiteLLMRouter(\n",
 " router=litellm_router,\n",
+" model_name=\"gpt-35-turbo\",\n",
 " streaming=True,\n",
 " verbose=True,\n",
 " callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
@@ -209,7 +210,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.13"
+"version": "3.11.9"
 }
 },
 "nbformat": 4,

View File

@@ -1,13 +1,6 @@
 """LiteLLM Router as LangChain Model."""
 
-from typing import (
-    Any,
-    AsyncIterator,
-    Iterator,
-    List,
-    Mapping,
-    Optional,
-)
+from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
 
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -17,15 +10,8 @@ from langchain_core.language_models.chat_models import (
     agenerate_from_stream,
     generate_from_stream,
 )
-from langchain_core.messages import (
-    AIMessageChunk,
-    BaseMessage,
-)
-from langchain_core.outputs import (
-    ChatGeneration,
-    ChatGenerationChunk,
-    ChatResult,
-)
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 
 from langchain_community.chat_models.litellm import (
     ChatLiteLLM,
@@ -33,8 +19,8 @@ from langchain_community.chat_models.litellm import (
     _convert_dict_to_message,
 )
 
-token_usage_key_name = "token_usage"
-model_extra_key_name = "model_extra"
+token_usage_key_name = "token_usage"  # nosec # incorrectly flagged as password
+model_extra_key_name = "model_extra"  # nosec # incorrectly flagged as password
 
 
 def get_llm_output(usage: Any, **params: Any) -> dict:
@@ -56,21 +42,14 @@ class ChatLiteLLMRouter(ChatLiteLLM):
     def __init__(self, *, router: Any, **kwargs: Any) -> None:
         """Construct Chat LiteLLM Router."""
-        super().__init__(**kwargs)
+        super().__init__(router=router, **kwargs)  # type: ignore
         self.router = router
 
     @property
     def _llm_type(self) -> str:
         return "LiteLLMRouter"
 
-    def _set_model_for_completion(self) -> None:
-        # use first model name (aka: model group),
-        # since we can only pass one to the router completion functions
-        self.model = self.router.model_list[0]["model_name"]
-
     def _prepare_params_for_router(self, params: Any) -> None:
-        params["model"] = self.model
-
         # allow the router to set api_base based on its model choice
         api_base_key_name = "api_base"
         if api_base_key_name in params and params[api_base_key_name] is None:
@@ -79,6 +58,22 @@ class ChatLiteLLMRouter(ChatLiteLLM):
         # add metadata so router can fill it below
         params.setdefault("metadata", {})
 
+    def set_default_model(self, model_name: str) -> None:
+        """Set the default model to use for completion calls.
+
+        Sets `self.model` to `model_name` if it is in the litellm router's
+        (`self.router`) model list. This provides the default model to use
+        for completion calls if no `model` kwarg is provided.
+        """
+        model_list = self.router.model_list
+        if not model_list:
+            raise ValueError("model_list is None or empty.")
+        for entry in model_list:
+            if entry["model_name"] == model_name:
+                self.model = model_name
+                return
+        raise ValueError(f"Model {model_name} not found in model_list.")
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -96,7 +91,6 @@ class ChatLiteLLMRouter(ChatLiteLLM):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)
 
         response = self.router.completion(
@@ -115,7 +109,6 @@ class ChatLiteLLMRouter(ChatLiteLLM):
         default_chunk_class = AIMessageChunk
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)
 
         for chunk in self.router.completion(messages=message_dicts, **params):
@@ -139,7 +132,6 @@ class ChatLiteLLMRouter(ChatLiteLLM):
         default_chunk_class = AIMessageChunk
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)
 
         async for chunk in await self.router.acompletion(
@@ -174,7 +166,6 @@ class ChatLiteLLMRouter(ChatLiteLLM):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)
 
         response = await self.router.acompletion(
@@ -196,14 +187,14 @@
             token_usage = output["token_usage"]
             if token_usage is not None:
                 # get dict from LiteLLM Usage class
-                for k, v in token_usage.dict().items():
-                    if k in overall_token_usage:
+                for k, v in token_usage.model_dump().items():
+                    if k in overall_token_usage and overall_token_usage[k] is not None:
                         overall_token_usage[k] += v
                     else:
                         overall_token_usage[k] = v
             if system_fingerprint is None:
                 system_fingerprint = output.get("system_fingerprint")
-        combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
+        combined = {"token_usage": overall_token_usage, "model_name": self.model}
         if system_fingerprint:
             combined["system_fingerprint"] = system_fingerprint
         return combined
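
The `set_default_model` helper added above can also switch the default model group after construction. A minimal sketch, reusing the `litellm_router` from the example in the description and assuming a second model group named "gpt-4" is also registered on the router:

```python
chat = ChatLiteLLMRouter(router=litellm_router, model_name="gpt-35-turbo")
# Point the wrapper at a different model group registered on the router;
# raises ValueError if the name is not present in router.model_list.
chat.set_default_model("gpt-4")
```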

View File

@@ -1,8 +1,20 @@
 """Test LiteLLM Router API wrapper."""
 
 import asyncio
+import queue
+import threading
 from copy import deepcopy
-from typing import Any, AsyncGenerator, Coroutine, Dict, List, Tuple, Union, cast
+from typing import (
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Dict,
+    Generator,
+    List,
+    Tuple,
+    Union,
+    cast,
+)
 
 import pytest
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
@@ -11,7 +23,8 @@ from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
-model_group = "gpt-4"
+model_group_gpt4 = "gpt-4"
+model_group_to_test = "gpt-35-turbo"
 fake_model_prefix = "azure/fake-deployment-name-"
 fake_models_names = [fake_model_prefix + suffix for suffix in ["1", "2"]]
 fake_api_key = "fakekeyvalue"
@@ -23,7 +36,7 @@ token_usage_key_name = "token_usage"
 
 model_list = [
     {
-        "model_name": model_group,
+        "model_name": model_group_gpt4,
         "litellm_params": {
             "model": fake_models_names[0],
             "api_key": fake_api_key,
@@ -32,7 +45,7 @@ model_list = [
         },
     },
     {
-        "model_name": model_group,
+        "model_name": model_group_to_test,
         "litellm_params": {
             "model": fake_models_names[1],
             "api_key": fake_api_key,
@@ -43,6 +56,39 @@ model_list = [
 ]
 
 
+# from https://stackoverflow.com/a/78573267
+def aiter_to_iter(it: AsyncIterator) -> Generator:
+    "Convert an async iterator into a regular (sync) iterator."
+    q_in: queue.SimpleQueue = queue.SimpleQueue()
+    q_out: queue.SimpleQueue = queue.SimpleQueue()
+
+    async def threadmain() -> None:
+        try:
+            # Wait until the sync generator requests an item before continuing
+            while q_in.get():
+                q_out.put((True, await it.__anext__()))
+        except StopAsyncIteration:
+            q_out.put((False, None))
+        except BaseException as ex:
+            q_out.put((False, ex))
+
+    thread = threading.Thread(target=asyncio.run, args=(threadmain(),), daemon=True)
+    thread.start()
+    try:
+        while True:
+            q_in.put(True)
+            cont, result = q_out.get()
+            if cont:
+                yield result
+            elif result is None:
+                break
+            else:
+                raise result
+    finally:
+        q_in.put(False)
+
+
 class FakeCompletion:
     def __init__(self) -> None:
         self.seen_inputs: List[Any] = []
@@ -55,13 +101,6 @@ class FakeCompletion:
         choices = cast(List[Dict[str, Any]], result["choices"])
         return result, choices
 
-    @staticmethod
-    def _get_next_result(
-        agen: AsyncGenerator[Dict[str, Any], None],
-    ) -> Dict[str, Any]:
-        coroutine = cast(Coroutine, agen.__anext__())
-        return asyncio.run(coroutine)
-
     async def _get_fake_results_agenerator(
         self, **kwargs: Any
     ) -> AsyncGenerator[Dict[str, Any], None]:
@@ -76,7 +115,7 @@ class FakeCompletion:
             ],
             "created": 0,
             "id": "",
-            "model": model_group,
+            "model": model_group_to_test,
             "object": "chat.completion",
         }
         if kwargs["stream"]:
@@ -115,17 +154,18 @@ class FakeCompletion:
 
     def completion(self, **kwargs: Any) -> Union[List, Dict[str, Any]]:
         agen = self._get_fake_results_agenerator(**kwargs)
+        synchronous_iter = aiter_to_iter(agen)
         if kwargs["stream"]:
             results: List[Dict[str, Any]] = []
             while True:
                 try:
-                    results.append(self._get_next_result(agen))
-                except StopAsyncIteration:
+                    results.append(synchronous_iter.__next__())
+                except StopIteration:
                     break
             return results
         else:
             # there is only one result for non-streaming
-            return self._get_next_result(agen)
+            return synchronous_iter.__next__()
 
     async def acompletion(
         self, **kwargs: Any
@@ -142,7 +182,7 @@ class FakeCompletion:
 
         for kwargs in self.seen_inputs:
             metadata = kwargs["metadata"]
 
-            assert metadata["model_group"] == model_group
+            assert metadata["model_group"] == model_group_to_test
             # LiteLLM router chooses one model name from the model_list
             assert kwargs["model"] in fake_models_names
@@ -172,17 +212,16 @@ def litellm_router() -> Any:
     """LiteLLM router for testing."""
     from litellm import Router
 
-    return Router(
-        model_list=model_list,
-    )
+    return Router(model_list=model_list)
 
 
 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 def test_litellm_router_call(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
     """Test valid call to LiteLLM Router."""
-    chat = ChatLiteLLMRouter(router=litellm_router)
+    chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test)
     message = HumanMessage(content="Hello")
 
     response = chat.invoke([message])
@@ -195,13 +234,12 @@ def test_litellm_router_call(
 
 
 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 def test_litellm_router_generate(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
     """Test generate method of LiteLLM Router."""
-    from litellm import Usage
-
-    chat = ChatLiteLLMRouter(router=litellm_router)
+    chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test)
     chat_messages: List[List[BaseMessage]] = [
         [HumanMessage(content="How many toes do dogs have?")]
     ]
@@ -219,18 +257,25 @@ def test_litellm_router_generate(
             assert generation.message.content == fake_answer
     assert chat_messages == messages_copy
     assert result.llm_output is not None
-    assert result.llm_output[token_usage_key_name] == Usage(
-        completion_tokens=1, prompt_tokens=2, total_tokens=3
-    )
+    assert result.llm_output[token_usage_key_name] == {
+        "completion_tokens": 1,
+        "completion_tokens_details": None,
+        "prompt_tokens": 2,
+        "prompt_tokens_details": None,
+        "total_tokens": 3,
+    }
     fake_completion.check_inputs(expected_num_calls=1)
 
 
 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 def test_litellm_router_streaming(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
     """Test streaming tokens from LiteLLM Router."""
-    chat = ChatLiteLLMRouter(router=litellm_router, streaming=True)
+    chat = ChatLiteLLMRouter(
+        router=litellm_router, model_name=model_group_to_test, streaming=True
+    )
     message = HumanMessage(content="Hello")
 
     response = chat.invoke([message])
@@ -243,6 +288,7 @@ def test_litellm_router_streaming(
 
 
 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 def test_litellm_router_streaming_callback(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
@@ -250,6 +296,7 @@ def test_litellm_router_streaming_callback(
     callback_handler = FakeCallbackHandler()
     chat = ChatLiteLLMRouter(
         router=litellm_router,
+        model_name=model_group_to_test,
         streaming=True,
         callbacks=[callback_handler],
         verbose=True,
@@ -267,13 +314,12 @@ def test_litellm_router_streaming_callback(
 
 
 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 async def test_async_litellm_router(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
     """Test async generation."""
-    from litellm import Usage
-
-    chat = ChatLiteLLMRouter(router=litellm_router)
+    chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test)
     message = HumanMessage(content="Hello")
 
     response = await chat.agenerate([[message], [message]])
@@ -288,13 +334,18 @@ async def test_async_litellm_router(
             assert generation.message.content == generation.text
             assert generation.message.content == fake_answer
     assert response.llm_output is not None
-    assert response.llm_output[token_usage_key_name] == Usage(
-        completion_tokens=2, prompt_tokens=4, total_tokens=6
-    )
+    assert response.llm_output[token_usage_key_name] == {
+        "completion_tokens": 2,
+        "completion_tokens_details": None,
+        "prompt_tokens": 4,
+        "prompt_tokens_details": None,
+        "total_tokens": 6,
+    }
     fake_completion.check_inputs(expected_num_calls=2)
 
 
 @pytest.mark.scheduled
+@pytest.mark.enable_socket
 async def test_async_litellm_router_streaming(
     fake_completion: FakeCompletion, litellm_router: Any
 ) -> None:
@@ -302,6 +353,7 @@ async def test_async_litellm_router_streaming(
     callback_handler = FakeCallbackHandler()
     chat = ChatLiteLLMRouter(
         router=litellm_router,
+        model_name=model_group_to_test,
         streaming=True,
         callbacks=[callback_handler],
         verbose=True,