Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-19 05:13:46 +00:00
community[minor]: Add LiteLLM Router Integration (#15588)
community:
- **Description:**
  - Add new `ChatLiteLLMRouter` class that allows a client to use a LiteLLM Router as a LangChain chat model.
  - Note: the existing `ChatLiteLLM` integration did not cover the LiteLLM Router class.
  - Add tests and a Jupyter notebook.
- **Issue:** None
- **Dependencies:** Relies on the existing `ChatLiteLLM` integration.
- **Twitter handle:** @bburgin_0

Co-authored-by: Bagatur <baskaryan@gmail.com>
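For quick orientation, here is a minimal usage sketch condensed from the notebook added below; the Azure model name, API key, and endpoint are placeholders rather than values from this PR:

```python
from langchain.schema import HumanMessage
from langchain_community.chat_models import ChatLiteLLMRouter
from litellm import Router

# One "gpt-4" model group backed by an Azure deployment (placeholder credentials).
model_list = [
    {
        "model_name": "gpt-4",
        "litellm_params": {
            "model": "azure/gpt-4-1106-preview",
            "api_key": "<your-api-key>",
            "api_version": "2023-05-15",
            "api_base": "https://<your-endpoint>.openai.azure.com/",
        },
    },
]

# The LiteLLM Router load-balances across the deployments registered for a model
# group; ChatLiteLLMRouter exposes that router as a standard LangChain chat model.
chat = ChatLiteLLMRouter(router=Router(model_list=model_list))
print(chat([HumanMessage(content="Translate this sentence from English to French. I love programming.")]))
```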
This commit is contained in: parent 35e60728b7, commit 148347e858

docs/docs/integrations/chat/litellm_router.ipynb (new file, 218 lines)
@@ -0,0 +1,218 @@
{
 "cells": [
  {
   "cell_type": "raw",
   "id": "59148044",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_label: LiteLLM Router\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "247da7a6",
   "metadata": {},
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bf733a38-db84-4363-89e2-de6735c37230",
   "metadata": {},
   "source": [
    "# ChatLiteLLMRouter\n",
    "\n",
    "[LiteLLM](https://github.com/BerriAI/litellm) is a library that simplifies calling Anthropic, Azure, Huggingface, Replicate, etc. \n",
    "\n",
    "This notebook covers how to get started with using Langchain + the LiteLLM Router I/O library. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.schema import HumanMessage\n",
    "from langchain_community.chat_models import ChatLiteLLMRouter\n",
    "from litellm import Router"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_list = [\n",
    "    {\n",
    "        \"model_name\": \"gpt-4\",\n",
    "        \"litellm_params\": {\n",
    "            \"model\": \"azure/gpt-4-1106-preview\",\n",
    "            \"api_key\": \"<your-api-key>\",\n",
    "            \"api_version\": \"2023-05-15\",\n",
    "            \"api_base\": \"https://<your-endpoint>.openai.azure.com/\",\n",
    "        },\n",
    "    },\n",
    "    {\n",
    "        \"model_name\": \"gpt-4\",\n",
    "        \"litellm_params\": {\n",
    "            \"model\": \"azure/gpt-4-1106-preview\",\n",
    "            \"api_key\": \"<your-api-key>\",\n",
    "            \"api_version\": \"2023-05-15\",\n",
    "            \"api_base\": \"https://<your-endpoint>.openai.azure.com/\",\n",
    "        },\n",
    "    },\n",
    "]\n",
    "litellm_router = Router(model_list=model_list)\n",
    "chat = ChatLiteLLMRouter(router=litellm_router)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"J'aime programmer.\")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages = [\n",
    "    HumanMessage(\n",
    "        content=\"Translate this sentence from English to French. I love programming.\"\n",
    "    )\n",
    "]\n",
    "chat(messages)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c",
   "metadata": {},
   "source": [
    "## `ChatLiteLLMRouter` also supports async and streaming functionality:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "93a21c5c-6ef9-4688-be60-b2e1f94842fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.callbacks.manager import CallbackManager\n",
    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LLMResult(generations=[[ChatGeneration(text=\"J'adore programmer.\", generation_info={'finish_reason': 'stop'}, message=AIMessage(content=\"J'adore programmer.\"))]], llm_output={'token_usage': {'completion_tokens': 6, 'prompt_tokens': 19, 'total_tokens': 25}, 'model_name': None}, run=[RunInfo(run_id=UUID('75003ec9-1e2b-43b7-a216-10dcc0f75e00'))])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await chat.agenerate([messages])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "J'adore programmer."
     ]
    },
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"J'adore programmer.\")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat = ChatLiteLLMRouter(\n",
    "    router=litellm_router,\n",
    "    streaming=True,\n",
    "    verbose=True,\n",
    "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
    ")\n",
    "chat(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c253883f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -40,6 +40,7 @@ from langchain_community.chat_models.javelin_ai_gateway import ChatJavelinAIGateway
 from langchain_community.chat_models.jinachat import JinaChat
 from langchain_community.chat_models.konko import ChatKonko
 from langchain_community.chat_models.litellm import ChatLiteLLM
+from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter
 from langchain_community.chat_models.llama_edge import LlamaEdgeChatService
 from langchain_community.chat_models.minimax import MiniMaxChat
 from langchain_community.chat_models.mlflow import ChatMlflow

@@ -78,6 +79,7 @@ __all__ = [
     "MiniMaxChat",
     "ChatAnyscale",
     "ChatLiteLLM",
+    "ChatLiteLLMRouter",
     "ErnieBotChat",
     "ChatJavelinAIGateway",
     "ChatKonko",
libs/community/langchain_community/chat_models/litellm_router.py (new file, 221 lines)

@@ -0,0 +1,221 @@
"""LiteLLM Router as LangChain Model."""
from typing import (
    Any,
    AsyncIterator,
    Iterator,
    List,
    Mapping,
    Optional,
)

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    ChatResult,
)

from langchain_community.chat_models.litellm import (
    ChatLiteLLM,
    _convert_delta_to_message_chunk,
    _convert_dict_to_message,
)

token_usage_key_name = "token_usage"
model_extra_key_name = "model_extra"


def get_llm_output(usage: Any, **params: Any) -> dict:
    """Get llm output from usage and params."""
    llm_output = {token_usage_key_name: usage}
    # copy over metadata (metadata came from router completion call)
    metadata = params["metadata"]
    for key in metadata:
        if key not in llm_output:
            # if token usage in metadata, prefer metadata's copy of it
            llm_output[key] = metadata[key]
    return llm_output


class ChatLiteLLMRouter(ChatLiteLLM):
    """LiteLLM Router as LangChain Model."""

    router: Any

    def __init__(self, *, router: Any, **kwargs: Any) -> None:
        """Construct Chat LiteLLM Router."""
        super().__init__(**kwargs)
        self.router = router

    @property
    def _llm_type(self) -> str:
        return "LiteLLMRouter"

    def _set_model_for_completion(self) -> None:
        # use first model name (aka: model group),
        # since we can only pass one to the router completion functions
        self.model = self.router.model_list[0]["model_name"]

    def _prepare_params_for_router(self, params: Any) -> None:
        params["model"] = self.model

        # allow the router to set api_base based on its model choice
        api_base_key_name = "api_base"
        if api_base_key_name in params and params[api_base_key_name] is None:
            del params[api_base_key_name]

        # add metadata so router can fill it below
        params.setdefault("metadata", {})

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        self._set_model_for_completion()
        self._prepare_params_for_router(params)

        response = self.router.completion(
            messages=message_dicts,
            **params,
        )
        return self._create_chat_result(response, **params)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        default_chunk_class = AIMessageChunk
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}
        self._set_model_for_completion()
        self._prepare_params_for_router(params)

        for chunk in self.router.completion(messages=message_dicts, **params):
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk)
            if run_manager:
                run_manager.on_llm_new_token(chunk.content, **params)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        default_chunk_class = AIMessageChunk
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}
        self._set_model_for_completion()
        self._prepare_params_for_router(params)

        async for chunk in await self.router.acompletion(
            messages=message_dicts, **params
        ):
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk)
            if run_manager:
                await run_manager.on_llm_new_token(chunk.content, **params)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        self._set_model_for_completion()
        self._prepare_params_for_router(params)

        response = await self.router.acompletion(
            messages=message_dicts,
            **params,
        )
        return self._create_chat_result(response, **params)

    # from
    # https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/openai.py
    # but modified to handle LiteLLM Usage class
    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        system_fingerprint = None
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            if token_usage is not None:
                # get dict from LiteLLM Usage class
                for k, v in token_usage.dict().items():
                    if k in overall_token_usage:
                        overall_token_usage[k] += v
                    else:
                        overall_token_usage[k] = v
            if system_fingerprint is None:
                system_fingerprint = output.get("system_fingerprint")
        combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
        if system_fingerprint:
            combined["system_fingerprint"] = system_fingerprint
        return combined

    def _create_chat_result(
        self, response: Mapping[str, Any], **params: Any
    ) -> ChatResult:
        from litellm.utils import Usage

        generations = []
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.get("finish_reason")),
            )
            generations.append(gen)
        token_usage = response.get("usage", Usage(prompt_tokens=0, total_tokens=0))
        llm_output = get_llm_output(token_usage, **params)
        return ChatResult(generations=generations, llm_output=llm_output)
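To make the role of `get_llm_output` and `_combine_llm_outputs` above concrete, here is a small sketch of how the aggregated token usage surfaces to callers; it assumes the `chat` and `messages` objects built in the notebook earlier in this commit and working credentials, so treat it as illustrative rather than part of the change:

```python
# Sketch only: `chat` and `messages` are assumed to come from the notebook above.
result = chat.generate([messages])

# llm_output carries the usage that get_llm_output() collected from the router's
# completion response and that _combine_llm_outputs() summed across generations,
# plus the model name.
print(result.llm_output["token_usage"])  # e.g. {'completion_tokens': 6, 'prompt_tokens': 19, 'total_tokens': 25}
print(result.llm_output["model_name"])   # may be None unless model_name is set on the chat model
```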
@@ -0,0 +1,326 @@
"""Test LiteLLM Router API wrapper."""
import asyncio
from copy import deepcopy
from typing import Any, AsyncGenerator, Coroutine, Dict, List, Tuple, Union, cast

import pytest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult

from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

model_group = "gpt-4"
fake_model_prefix = "azure/fake-deployment-name-"
fake_models_names = [fake_model_prefix + suffix for suffix in ["1", "2"]]
fake_api_key = "fakekeyvalue"
fake_api_version = "XXXX-XX-XX"
fake_api_base = "https://faketesturl/"
fake_chunks = ["This is ", "a fake answer."]
fake_answer = "".join(fake_chunks)
token_usage_key_name = "token_usage"

model_list = [
    {
        "model_name": model_group,
        "litellm_params": {
            "model": fake_models_names[0],
            "api_key": fake_api_key,
            "api_version": fake_api_version,
            "api_base": fake_api_base,
        },
    },
    {
        "model_name": model_group,
        "litellm_params": {
            "model": fake_models_names[1],
            "api_key": fake_api_key,
            "api_version": fake_api_version,
            "api_base": fake_api_base,
        },
    },
]


class FakeCompletion:
    def __init__(self) -> None:
        self.seen_inputs: List[Any] = []

    @staticmethod
    def _get_new_result_and_choices(
        base_result: Dict[str, Any],
    ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
        result = deepcopy(base_result)
        choices = cast(List[Dict[str, Any]], result["choices"])
        return result, choices

    @staticmethod
    def _get_next_result(
        agen: AsyncGenerator[Dict[str, Any], None],
    ) -> Dict[str, Any]:
        coroutine = cast(Coroutine, agen.__anext__())
        return asyncio.run(coroutine)

    async def _get_fake_results_agenerator(
        self, **kwargs: Any
    ) -> AsyncGenerator[Dict[str, Any], None]:
        from litellm import Usage

        self.seen_inputs.append(kwargs)
        base_result = {
            "choices": [
                {
                    "index": 0,
                }
            ],
            "created": 0,
            "id": "",
            "model": model_group,
            "object": "chat.completion",
        }
        if kwargs["stream"]:
            for chunk_index in range(0, len(fake_chunks)):
                result, choices = self._get_new_result_and_choices(base_result)
                choice = choices[0]
                choice["delta"] = {
                    "role": "assistant",
                    "content": fake_chunks[chunk_index],
                    "function_call": None,
                }
                choice["finish_reason"] = None
                # no usage here, since no usage from OpenAI API for streaming yet
                # https://community.openai.com/t/usage-info-in-api-responses/18862
                yield result

            result, choices = self._get_new_result_and_choices(base_result)
            choice = choices[0]
            choice["delta"] = {}
            choice["finish_reason"] = "stop"
            # no usage here, since no usage from OpenAI API for streaming yet
            # https://community.openai.com/t/usage-info-in-api-responses/18862
            yield result
        else:
            result, choices = self._get_new_result_and_choices(base_result)
            choice = choices[0]
            choice["message"] = {
                "content": fake_answer,
                "role": "assistant",
            }
            choice["finish_reason"] = "stop"
            result["usage"] = Usage(
                completion_tokens=1, prompt_tokens=2, total_tokens=3
            )
            yield result

    def completion(self, **kwargs: Any) -> Union[List, Dict[str, Any]]:
        agen = self._get_fake_results_agenerator(**kwargs)
        if kwargs["stream"]:
            results: List[Dict[str, Any]] = []
            while True:
                try:
                    results.append(self._get_next_result(agen))
                except StopAsyncIteration:
                    break
            return results
        else:
            # there is only one result for non-streaming
            return self._get_next_result(agen)

    async def acompletion(
        self, **kwargs: Any
    ) -> Union[AsyncGenerator[Dict[str, Any], None], Dict[str, Any]]:
        agen = self._get_fake_results_agenerator(**kwargs)
        if kwargs["stream"]:
            return agen
        else:
            # there is only one result for non-streaming
            return await agen.__anext__()

    def check_inputs(self, expected_num_calls: int) -> None:
        assert len(self.seen_inputs) == expected_num_calls
        for kwargs in self.seen_inputs:
            metadata = kwargs["metadata"]

            assert metadata["model_group"] == model_group

            # LiteLLM router chooses one model name from the model_list
            assert kwargs["model"] in fake_models_names
            assert metadata["deployment"] in fake_models_names

            assert kwargs["api_key"] == fake_api_key
            assert kwargs["api_version"] == fake_api_version
            assert kwargs["api_base"] == fake_api_base


@pytest.fixture
def fake_completion() -> FakeCompletion:
    """Fake AI completion for testing."""
    import litellm

    fake_completion = FakeCompletion()

    # Turn off LiteLLM's built-in telemetry
    litellm.telemetry = False
    litellm.completion = fake_completion.completion
    litellm.acompletion = fake_completion.acompletion
    return fake_completion


@pytest.fixture
def litellm_router() -> Any:
    """LiteLLM router for testing."""
    from litellm import Router

    return Router(
        model_list=model_list,
    )


@pytest.mark.scheduled
def test_litellm_router_call(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """Test valid call to LiteLLM Router."""
    chat = ChatLiteLLMRouter(router=litellm_router)
    message = HumanMessage(content="Hello")

    response = chat([message])

    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert response.content == fake_answer
    # no usage check here, since response is only an AIMessage
    fake_completion.check_inputs(expected_num_calls=1)


@pytest.mark.scheduled
def test_litellm_router_generate(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """Test generate method of LiteLLM Router."""
    from litellm import Usage

    chat = ChatLiteLLMRouter(router=litellm_router)
    chat_messages: List[List[BaseMessage]] = [
        [HumanMessage(content="How many toes do dogs have?")]
    ]
    messages_copy = [messages.copy() for messages in chat_messages]

    result: LLMResult = chat.generate(chat_messages)

    assert isinstance(result, LLMResult)
    for generations in result.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.message.content == generation.text
            assert generation.message.content == fake_answer
    assert chat_messages == messages_copy
    assert result.llm_output is not None
    assert result.llm_output[token_usage_key_name] == Usage(
        completion_tokens=1, prompt_tokens=2, total_tokens=3
    )
    fake_completion.check_inputs(expected_num_calls=1)


@pytest.mark.scheduled
def test_litellm_router_streaming(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """Test streaming tokens from LiteLLM Router."""
    chat = ChatLiteLLMRouter(router=litellm_router, streaming=True)
    message = HumanMessage(content="Hello")

    response = chat([message])

    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert response.content == fake_answer
    # no usage check here, since response is only an AIMessage
    fake_completion.check_inputs(expected_num_calls=1)


@pytest.mark.scheduled
def test_litellm_router_streaming_callback(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    chat = ChatLiteLLMRouter(
        router=litellm_router,
        streaming=True,
        callbacks=[callback_handler],
        verbose=True,
    )
    message = HumanMessage(content="Write me a sentence with 10 words.")

    response = chat([message])

    assert callback_handler.llm_streams > 1
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert response.content == fake_answer
    # no usage check here, since response is only an AIMessage
    fake_completion.check_inputs(expected_num_calls=1)


@pytest.mark.asyncio
@pytest.mark.scheduled
async def test_async_litellm_router(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """Test async generation."""
    from litellm import Usage

    chat = ChatLiteLLMRouter(router=litellm_router)
    message = HumanMessage(content="Hello")

    response = await chat.agenerate([[message], [message]])

    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.message.content == generation.text
            assert generation.message.content == fake_answer
    assert response.llm_output is not None
    assert response.llm_output[token_usage_key_name] == Usage(
        completion_tokens=2, prompt_tokens=4, total_tokens=6
    )
    fake_completion.check_inputs(expected_num_calls=2)


@pytest.mark.asyncio
@pytest.mark.scheduled
async def test_async_litellm_router_streaming(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    chat = ChatLiteLLMRouter(
        router=litellm_router,
        streaming=True,
        callbacks=[callback_handler],
        verbose=True,
    )
    message = HumanMessage(content="Hello")

    response = await chat.agenerate([[message], [message]])

    assert callback_handler.llm_streams > 0
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.message.content == generation.text
            assert generation.message.content == fake_answer
    # no usage check here, since no usage from OpenAI API for streaming yet
    # https://community.openai.com/t/usage-info-in-api-responses/18862
    fake_completion.check_inputs(expected_num_calls=2)
@@ -22,6 +22,7 @@ EXPECTED_ALL = [
     "MiniMaxChat",
     "ChatAnyscale",
     "ChatLiteLLM",
+    "ChatLiteLLMRouter",
     "ErnieBotChat",
     "ChatJavelinAIGateway",
     "ChatKonko",