community[minor]: Add LiteLLM Router Integration (#15588)

community:

  - **Description:**
    - Add a new ChatLiteLLMRouter class that lets a client use a LiteLLM Router as a LangChain chat model (see the usage sketch below).
    - Note: the existing ChatLiteLLM integration did not cover the LiteLLM Router class.
    - Add tests and a Jupyter notebook.
  - **Issue:** None
  - **Dependencies:** Relies on the existing ChatLiteLLM integration
  - **Twitter handle:** @bburgin_0
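
For orientation, a minimal usage sketch of how the new class is wired up, in the spirit of the tests and notebook added in this PR. The deployment name, endpoint, key, and API version below are placeholders, not values from this change:

```python
from litellm import Router

from langchain_community.chat_models import ChatLiteLLMRouter
from langchain_core.messages import HumanMessage

# Placeholder Azure OpenAI deployment; swap in real credentials and endpoints.
model_list = [
    {
        "model_name": "gpt-4",  # model group name the router exposes
        "litellm_params": {
            "model": "azure/your-gpt-4-deployment",
            "api_key": "<azure-api-key>",
            "api_base": "https://<your-endpoint>.openai.azure.com/",
            "api_version": "2023-05-15",
        },
    },
]

router = Router(model_list=model_list)
chat = ChatLiteLLMRouter(router=router)

# Synchronous call; streaming and async generation go through the same class.
response = chat([HumanMessage(content="Hello")])
print(response.content)
```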

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Commit 148347e858 (parent 35e60728b7)
Authored by Brian Burgin on 2024-01-25 13:03:05 -06:00; committed by GitHub
5 changed files with 768 additions and 0 deletions


@@ -40,6 +40,7 @@ from langchain_community.chat_models.javelin_ai_gateway import ChatJavelinAIGate
from langchain_community.chat_models.jinachat import JinaChat
from langchain_community.chat_models.konko import ChatKonko
from langchain_community.chat_models.litellm import ChatLiteLLM
from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter
from langchain_community.chat_models.llama_edge import LlamaEdgeChatService
from langchain_community.chat_models.minimax import MiniMaxChat
from langchain_community.chat_models.mlflow import ChatMlflow
@@ -78,6 +79,7 @@ __all__ = [
"MiniMaxChat",
"ChatAnyscale",
"ChatLiteLLM",
"ChatLiteLLMRouter",
"ErnieBotChat",
"ChatJavelinAIGateway",
"ChatKonko",


@@ -0,0 +1,221 @@
"""LiteLLM Router as LangChain Model."""
from typing import (
Any,
AsyncIterator,
Iterator,
List,
Mapping,
Optional,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
)
from langchain_community.chat_models.litellm import (
ChatLiteLLM,
_convert_delta_to_message_chunk,
_convert_dict_to_message,
)
token_usage_key_name = "token_usage"
model_extra_key_name = "model_extra"
def get_llm_output(usage: Any, **params: Any) -> dict:
"""Get llm output from usage and params."""
llm_output = {token_usage_key_name: usage}
# copy over metadata (metadata came from router completion call)
metadata = params["metadata"]
for key in metadata:
if key not in llm_output:
# if token usage in metadata, prefer metadata's copy of it
llm_output[key] = metadata[key]
return llm_output
class ChatLiteLLMRouter(ChatLiteLLM):
"""LiteLLM Router as LangChain Model."""
router: Any
def __init__(self, *, router: Any, **kwargs: Any) -> None:
"""Construct Chat LiteLLM Router."""
super().__init__(**kwargs)
self.router = router
@property
def _llm_type(self) -> str:
return "LiteLLMRouter"
def _set_model_for_completion(self) -> None:
# use first model name (aka: model group),
# since we can only pass one to the router completion functions
self.model = self.router.model_list[0]["model_name"]
def _prepare_params_for_router(self, params: Any) -> None:
params["model"] = self.model
# allow the router to set api_base based on its model choice
api_base_key_name = "api_base"
if api_base_key_name in params and params[api_base_key_name] is None:
del params[api_base_key_name]
# add metadata so router can fill it below
params.setdefault("metadata", {})
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
self._set_model_for_completion()
self._prepare_params_for_router(params)
response = self.router.completion(
messages=message_dicts,
**params,
)
return self._create_chat_result(response, **params)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
default_chunk_class = AIMessageChunk
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
self._set_model_for_completion()
self._prepare_params_for_router(params)
for chunk in self.router.completion(messages=message_dicts, **params):
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content, **params)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
default_chunk_class = AIMessageChunk
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
self._set_model_for_completion()
self._prepare_params_for_router(params)
async for chunk in await self.router.acompletion(
messages=message_dicts, **params
):
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
await run_manager.on_llm_new_token(chunk.content, **params)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
self._set_model_for_completion()
self._prepare_params_for_router(params)
response = await self.router.acompletion(
messages=message_dicts,
**params,
)
return self._create_chat_result(response, **params)
# from
# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/openai.py
# but modified to handle LiteLLM Usage class
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
if token_usage is not None:
# get dict from LiteLLM Usage class
for k, v in token_usage.dict().items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
if system_fingerprint is None:
system_fingerprint = output.get("system_fingerprint")
combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
if system_fingerprint:
combined["system_fingerprint"] = system_fingerprint
return combined
def _create_chat_result(
self, response: Mapping[str, Any], **params: Any
) -> ChatResult:
from litellm.utils import Usage
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.get("finish_reason")),
)
generations.append(gen)
token_usage = response.get("usage", Usage(prompt_tokens=0, total_tokens=0))
llm_output = get_llm_output(token_usage, **params)
return ChatResult(generations=generations, llm_output=llm_output)


@@ -0,0 +1,326 @@
"""Test LiteLLM Router API wrapper."""
import asyncio
from copy import deepcopy
from typing import Any, AsyncGenerator, Coroutine, Dict, List, Tuple, Union, cast
import pytest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
model_group = "gpt-4"
fake_model_prefix = "azure/fake-deployment-name-"
fake_models_names = [fake_model_prefix + suffix for suffix in ["1", "2"]]
fake_api_key = "fakekeyvalue"
fake_api_version = "XXXX-XX-XX"
fake_api_base = "https://faketesturl/"
fake_chunks = ["This is ", "a fake answer."]
fake_answer = "".join(fake_chunks)
token_usage_key_name = "token_usage"
model_list = [
{
"model_name": model_group,
"litellm_params": {
"model": fake_models_names[0],
"api_key": fake_api_key,
"api_version": fake_api_version,
"api_base": fake_api_base,
},
},
{
"model_name": model_group,
"litellm_params": {
"model": fake_models_names[1],
"api_key": fake_api_key,
"api_version": fake_api_version,
"api_base": fake_api_base,
},
},
]
class FakeCompletion:
def __init__(self) -> None:
self.seen_inputs: List[Any] = []
@staticmethod
def _get_new_result_and_choices(
base_result: Dict[str, Any],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
result = deepcopy(base_result)
choices = cast(List[Dict[str, Any]], result["choices"])
return result, choices
@staticmethod
def _get_next_result(
agen: AsyncGenerator[Dict[str, Any], None],
) -> Dict[str, Any]:
coroutine = cast(Coroutine, agen.__anext__())
return asyncio.run(coroutine)
async def _get_fake_results_agenerator(
self, **kwargs: Any
) -> AsyncGenerator[Dict[str, Any], None]:
from litellm import Usage
self.seen_inputs.append(kwargs)
base_result = {
"choices": [
{
"index": 0,
}
],
"created": 0,
"id": "",
"model": model_group,
"object": "chat.completion",
}
if kwargs["stream"]:
for chunk_index in range(0, len(fake_chunks)):
result, choices = self._get_new_result_and_choices(base_result)
choice = choices[0]
choice["delta"] = {
"role": "assistant",
"content": fake_chunks[chunk_index],
"function_call": None,
}
choice["finish_reason"] = None
# no usage here, since no usage from OpenAI API for streaming yet
# https://community.openai.com/t/usage-info-in-api-responses/18862
yield result
result, choices = self._get_new_result_and_choices(base_result)
choice = choices[0]
choice["delta"] = {}
choice["finish_reason"] = "stop"
# no usage here, since no usage from OpenAI API for streaming yet
# https://community.openai.com/t/usage-info-in-api-responses/18862
yield result
else:
result, choices = self._get_new_result_and_choices(base_result)
choice = choices[0]
choice["message"] = {
"content": fake_answer,
"role": "assistant",
}
choice["finish_reason"] = "stop"
result["usage"] = Usage(
completion_tokens=1, prompt_tokens=2, total_tokens=3
)
yield result
def completion(self, **kwargs: Any) -> Union[List, Dict[str, Any]]:
agen = self._get_fake_results_agenerator(**kwargs)
if kwargs["stream"]:
results: List[Dict[str, Any]] = []
while True:
try:
results.append(self._get_next_result(agen))
except StopAsyncIteration:
break
return results
else:
# there is only one result for non-streaming
return self._get_next_result(agen)
async def acompletion(
self, **kwargs: Any
) -> Union[AsyncGenerator[Dict[str, Any], None], Dict[str, Any]]:
agen = self._get_fake_results_agenerator(**kwargs)
if kwargs["stream"]:
return agen
else:
# there is only one result for non-streaming
return await agen.__anext__()
def check_inputs(self, expected_num_calls: int) -> None:
assert len(self.seen_inputs) == expected_num_calls
for kwargs in self.seen_inputs:
metadata = kwargs["metadata"]
assert metadata["model_group"] == model_group
# LiteLLM router chooses one model name from the model_list
assert kwargs["model"] in fake_models_names
assert metadata["deployment"] in fake_models_names
assert kwargs["api_key"] == fake_api_key
assert kwargs["api_version"] == fake_api_version
assert kwargs["api_base"] == fake_api_base
@pytest.fixture
def fake_completion() -> FakeCompletion:
"""Fake AI completion for testing."""
import litellm
fake_completion = FakeCompletion()
# Turn off LiteLLM's built-in telemetry
litellm.telemetry = False
litellm.completion = fake_completion.completion
litellm.acompletion = fake_completion.acompletion
return fake_completion
@pytest.fixture
def litellm_router() -> Any:
"""LiteLLM router for testing."""
from litellm import Router
return Router(
model_list=model_list,
)
@pytest.mark.scheduled
def test_litellm_router_call(
fake_completion: FakeCompletion, litellm_router: Any
) -> None:
"""Test valid call to LiteLLM Router."""
chat = ChatLiteLLMRouter(router=litellm_router)
message = HumanMessage(content="Hello")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content == fake_answer
# no usage check here, since response is only an AIMessage
fake_completion.check_inputs(expected_num_calls=1)
@pytest.mark.scheduled
def test_litellm_router_generate(
fake_completion: FakeCompletion, litellm_router: Any
) -> None:
"""Test generate method of LiteLLM Router."""
from litellm import Usage
chat = ChatLiteLLMRouter(router=litellm_router)
chat_messages: List[List[BaseMessage]] = [
[HumanMessage(content="How many toes do dogs have?")]
]
messages_copy = [messages.copy() for messages in chat_messages]
result: LLMResult = chat.generate(chat_messages)
assert isinstance(result, LLMResult)
for generations in result.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.message.content == generation.text
assert generation.message.content == fake_answer
assert chat_messages == messages_copy
assert result.llm_output is not None
assert result.llm_output[token_usage_key_name] == Usage(
completion_tokens=1, prompt_tokens=2, total_tokens=3
)
fake_completion.check_inputs(expected_num_calls=1)
@pytest.mark.scheduled
def test_litellm_router_streaming(
fake_completion: FakeCompletion, litellm_router: Any
) -> None:
"""Test streaming tokens from LiteLLM Router."""
chat = ChatLiteLLMRouter(router=litellm_router, streaming=True)
message = HumanMessage(content="Hello")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content == fake_answer
# no usage check here, since response is only an AIMessage
fake_completion.check_inputs(expected_num_calls=1)
@pytest.mark.scheduled
def test_litellm_router_streaming_callback(
fake_completion: FakeCompletion, litellm_router: Any
) -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
chat = ChatLiteLLMRouter(
router=litellm_router,
streaming=True,
callbacks=[callback_handler],
verbose=True,
)
message = HumanMessage(content="Write me a sentence with 10 words.")
response = chat([message])
assert callback_handler.llm_streams > 1
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content == fake_answer
# no usage check here, since response is only an AIMessage
fake_completion.check_inputs(expected_num_calls=1)
@pytest.mark.asyncio
@pytest.mark.scheduled
async def test_async_litellm_router(
fake_completion: FakeCompletion, litellm_router: Any
) -> None:
"""Test async generation."""
from litellm import Usage
chat = ChatLiteLLMRouter(router=litellm_router)
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.message.content == generation.text
assert generation.message.content == fake_answer
assert response.llm_output is not None
assert response.llm_output[token_usage_key_name] == Usage(
completion_tokens=2, prompt_tokens=4, total_tokens=6
)
fake_completion.check_inputs(expected_num_calls=2)
@pytest.mark.asyncio
@pytest.mark.scheduled
async def test_async_litellm_router_streaming(
fake_completion: FakeCompletion, litellm_router: Any
) -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
chat = ChatLiteLLMRouter(
router=litellm_router,
streaming=True,
callbacks=[callback_handler],
verbose=True,
)
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert callback_handler.llm_streams > 0
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.message.content == generation.text
assert generation.message.content == fake_answer
# no usage check here, since no usage from OpenAI API for streaming yet
# https://community.openai.com/t/usage-info-in-api-responses/18862
fake_completion.check_inputs(expected_num_calls=2)


@@ -22,6 +22,7 @@ EXPECTED_ALL = [
"MiniMaxChat",
"ChatAnyscale",
"ChatLiteLLM",
"ChatLiteLLMRouter",
"ErnieBotChat",
"ChatJavelinAIGateway",
"ChatKonko",