community[minor]: Integrating GPTRouter (#14900)

**Description:** Adds a LangChain integration for
[GPTRouter](https://gpt-router.writesonic.com/) 🚀
 **Tag maintainer:** @Gupta-Anubhav12 @samanyougarg @sirjan-ws-ext  
 **Twitter handle:** [@SamanyouGarg](https://twitter.com/SamanyouGarg)
 
Integration Tests Passing:
<img width="1137" alt="Screenshot 2023-12-19 at 5 45 31 PM"
src="https://github.com/Writesonic/langchain/assets/151817113/4a59df9a-ee30-47aa-9df9-b8c4eeb9dc76">
Sirjanpreet Singh Banga 2023-12-19 20:38:36 +05:30 committed by GitHub
parent 1069a93d18
commit 44cb899a93
5 changed files with 708 additions and 0 deletions


@@ -0,0 +1,231 @@
{
"cells": [
{
"cell_type": "raw",
"id": "59148044",
"metadata": {},
"source": [
"---\n",
"sidebar_label: GPTRouter\n",
"---"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "bf733a38-db84-4363-89e2-de6735c37230",
"metadata": {},
"source": [
"# ChatGPTRouter\n",
"\n",
"[GPTRouter](https://github.com/Writesonic/GPTRouter) is an open source LLM API Gateway that offers a universal API for 30+ LLMs, vision, and image models, with smart fallbacks based on uptime and latency, automatic retries, and streaming.\n",
"\n",
" \n",
"This notebook covers how to get started with using Langchain + the GPTRouter I/O library. \n",
"\n",
"* Set `GPT_ROUTER_API_KEY` environment variable\n",
"* or use the `gpt_router_api_key` keyword argument"
]
},
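{
"cell_type": "markdown",
"id": "b2d9f1aa",
"metadata": {},
"source": [
"Below is a minimal key-setup sketch (illustrative; skip it if the variable is already exported): it reads `GPT_ROUTER_API_KEY` from the environment and prompts for it if absent."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4e7d8b1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from getpass import getpass\n",
"\n",
"# Prompt for the key only when it is not already in the environment.\n",
"if \"GPT_ROUTER_API_KEY\" not in os.environ:\n",
"    os.environ[\"GPT_ROUTER_API_KEY\"] = getpass(\"Enter your GPTRouter API key: \")"
]
},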
{
"cell_type": "code",
"execution_count": 14,
"id": "d0133ddd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: GPTRouter in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (0.1.3)\n",
"Requirement already satisfied: pydantic==2.5.2 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from GPTRouter) (2.5.2)\n",
"Requirement already satisfied: httpx>=0.25.2 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from GPTRouter) (0.25.2)\n",
"Requirement already satisfied: annotated-types>=0.4.0 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from pydantic==2.5.2->GPTRouter) (0.6.0)\n",
"Requirement already satisfied: pydantic-core==2.14.5 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from pydantic==2.5.2->GPTRouter) (2.14.5)\n",
"Requirement already satisfied: typing-extensions>=4.6.1 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from pydantic==2.5.2->GPTRouter) (4.8.0)\n",
"Requirement already satisfied: idna in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (3.6)\n",
"Requirement already satisfied: anyio in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (3.7.1)\n",
"Requirement already satisfied: sniffio in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (1.3.0)\n",
"Requirement already satisfied: certifi in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (2023.11.17)\n",
"Requirement already satisfied: httpcore==1.* in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (1.0.2)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpcore==1.*->httpx>=0.25.2->GPTRouter) (0.14.0)\n",
"Requirement already satisfied: exceptiongroup in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from anyio->httpx>=0.25.2->GPTRouter) (1.2.0)\n",
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install GPTRouter"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.schema import HumanMessage\n",
"from langchain_community.chat_models import ChatGPTRouter\n",
"from langchain_community.chat_models.gpt_router import GPTRouterModel"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b8a9914b",
"metadata": {},
"outputs": [],
"source": [
"anthropic_claude = GPTRouterModel(name=\"claude-instant-1.2\", provider_name=\"anthropic\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat = ChatGPTRouter(models_priority_list=[anthropic_claude])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\" J'aime programmer.\")"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" HumanMessage(\n",
" content=\"Translate this sentence from English to French. I love programming.\"\n",
" )\n",
"]\n",
"chat(messages)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c",
"metadata": {},
"source": [
"## `ChatGPTRouter` also supports async and streaming functionality:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "93a21c5c-6ef9-4688-be60-b2e1f94842fb",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.callbacks.manager import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"LLMResult(generations=[[ChatGeneration(text=\" J'aime programmer.\", generation_info={'finish_reason': 'stop_sequence'}, message=AIMessage(content=\" J'aime programmer.\"))]], llm_output={}, run=[RunInfo(run_id=UUID('9885f27f-c35a-4434-9f37-c254259762a5'))])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await chat.agenerate([messages])"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" J'aime programmer."
]
},
{
"data": {
"text/plain": [
"AIMessage(content=\" J'aime programmer.\")"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat = ChatGPTRouter(\n",
" models_priority_list=[anthropic_claude],\n",
" streaming=True,\n",
" verbose=True,\n",
" callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
")\n",
"chat(messages)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -31,6 +31,7 @@ from langchain_community.chat_models.fake import FakeListChatModel
from langchain_community.chat_models.fireworks import ChatFireworks
from langchain_community.chat_models.gigachat import GigaChat
from langchain_community.chat_models.google_palm import ChatGooglePalm
from langchain_community.chat_models.gpt_router import ChatGPTRouter
from langchain_community.chat_models.human import HumanInputChatModel
from langchain_community.chat_models.hunyuan import ChatHunyuan
from langchain_community.chat_models.javelin_ai_gateway import ChatJavelinAIGateway
@@ -79,4 +80,5 @@ __all__ = [
"ChatHunyuan",
"GigaChat",
"VolcEngineMaasChat",
"ChatGPTRouter",
]


@@ -0,0 +1,390 @@
from __future__ import annotations

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_community.adapters.openai import (
    convert_dict_to_message,
    convert_message_to_dict,
)
from langchain_community.chat_models.openai import _convert_delta_to_message_chunk

if TYPE_CHECKING:
    from gpt_router.models import ChunkedGenerationResponse, GenerationResponse

logger = logging.getLogger(__name__)

DEFAULT_API_BASE_URL = "https://gpt-router-preview.writesonic.com"


class GPTRouterException(Exception):
    """Error with the `GPTRouter APIs`"""


class GPTRouterModel(BaseModel):
    """A single model choice, identified by model ``name`` and ``provider_name``."""

    name: str
    provider_name: str


def get_ordered_generation_requests(
    models_priority_list: List[GPTRouterModel], **kwargs
):
    """
    Return the body for the model router input.
    """
    from gpt_router.models import GenerationParams, ModelGenerationRequest

    return [
        ModelGenerationRequest(
            model_name=model.name,
            provider_name=model.provider_name,
            order=index + 1,
            prompt_params=GenerationParams(**kwargs),
        )
        for index, model in enumerate(models_priority_list)
    ]


def _create_retry_decorator(
    llm: ChatGPTRouter,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    from gpt_router import exceptions

    errors = [
        exceptions.GPTRouterApiTimeoutError,
        exceptions.GPTRouterInternalServerError,
        exceptions.GPTRouterNotAvailableError,
        exceptions.GPTRouterTooManyRequestsError,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )


def completion_with_retry(
    llm: ChatGPTRouter,
    models_priority_list: List[GPTRouterModel],
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        ordered_generation_requests = get_ordered_generation_requests(
            models_priority_list, **kwargs
        )
        return llm.client.generate(
            ordered_generation_requests=ordered_generation_requests,
            is_stream=kwargs.get("stream", False),
        )

    return _completion_with_retry(**kwargs)


async def acompletion_with_retry(
    llm: ChatGPTRouter,
    models_priority_list: List[GPTRouterModel],
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        ordered_generation_requests = get_ordered_generation_requests(
            models_priority_list, **kwargs
        )
        return await llm.client.agenerate(
            ordered_generation_requests=ordered_generation_requests,
            is_stream=kwargs.get("stream", False),
        )

    return await _completion_with_retry(**kwargs)


class ChatGPTRouter(BaseChatModel):
    """GPTRouter by Writesonic Inc.

    For more information, see https://gpt-router.writesonic.com/docs

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    models_priority_list: List[GPTRouterModel] = Field(min_items=1)
    gpt_router_api_base: Optional[str] = Field(default=None)
    """WriteSonic GPTRouter custom endpoint"""
    gpt_router_api_key: Optional[str] = None
    """WriteSonic GPTRouter API Key"""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    max_retries: int = 4
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: int = 256
    """Maximum number of tokens to generate per completion."""

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        values["gpt_router_api_base"] = get_from_dict_or_env(
            values,
            "gpt_router_api_base",
            "GPT_ROUTER_API_BASE",
            DEFAULT_API_BASE_URL,
        )
        values["gpt_router_api_key"] = get_from_dict_or_env(
            values,
            "gpt_router_api_key",
            "GPT_ROUTER_API_KEY",
        )
        try:
            from gpt_router.client import GPTRouterClient
        except ImportError:
            raise GPTRouterException(
                "Could not import GPTRouter python package. "
                "Please install it with `pip install GPTRouter`."
            )
        gpt_router_client = GPTRouterClient(
            values["gpt_router_api_base"], values["gpt_router_api_key"]
        )
        values["client"] = gpt_router_client
        return values

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {
            "gpt_router_api_key": "GPT_ROUTER_API_KEY",
        }

    @property
    def lc_serializable(self) -> bool:
        return True

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "gpt-router-chat"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"models_priority_list": self.models_priority_list},
            **self._default_params,
        }

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling GPTRouter API."""
        return {
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            **self.model_kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        stream: bool | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": False}
        response = completion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        stream: bool | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": False}
        response = await acompletion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        return self._create_chat_result(response)

    def _create_chat_generation_chunk(
        self, data: Mapping[str, Any], default_chunk_class
    ):
        chunk = _convert_delta_to_message_chunk(
            {"content": data.get("text", "")}, default_chunk_class
        )
        finish_reason = data.get("finish_reason")
        generation_info = (
            dict(finish_reason=finish_reason) if finish_reason is not None else None
        )
        default_chunk_class = chunk.__class__
        chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
        return chunk, default_chunk_class

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}
        default_chunk_class = AIMessageChunk
        generator_response = completion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        for chunk in generator_response:
            if chunk.event != "update":
                continue
            chunk, default_chunk_class = self._create_chat_generation_chunk(
                chunk.data, default_chunk_class
            )
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    token=chunk.message.content, chunk=chunk.message
                )

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}
        default_chunk_class = AIMessageChunk
        generator_response = acompletion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        async for chunk in await generator_response:
            if chunk.event != "update":
                continue
            chunk, default_chunk_class = self._create_chat_generation_chunk(
                chunk.data, default_chunk_class
            )
            yield chunk
            if run_manager:
                await run_manager.on_llm_new_token(
                    token=chunk.message.content, chunk=chunk.message
                )

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._default_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: GenerationResponse) -> ChatResult:
        generations = []
        for res in response.choices:
            message = convert_dict_to_message(
                {
                    "role": "assistant",
                    "content": res.text,
                }
            )
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.finish_reason),
            )
            generations.append(gen)
        llm_output = {"token_usage": response.meta, "model": response.model}
        return ChatResult(generations=generations, llm_output=llm_output)
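
The `models_priority_list` is what drives GPTRouter's fallback behavior: each entry becomes a `ModelGenerationRequest` with an ascending `order`, and the gateway tries them in that order. A sketch with a two-model priority list (the `gpt-3.5-turbo` / `openai` fallback entry is illustrative and assumes that provider is enabled on your GPTRouter deployment):

```python
from langchain_community.chat_models import ChatGPTRouter
from langchain_community.chat_models.gpt_router import GPTRouterModel
from langchain_core.messages import HumanMessage

# Primary model first; if it is unavailable or rate-limited,
# GPTRouter falls through to the next entry in the list.
primary = GPTRouterModel(name="claude-instant-1.2", provider_name="anthropic")
fallback = GPTRouterModel(name="gpt-3.5-turbo", provider_name="openai")

chat = ChatGPTRouter(models_priority_list=[primary, fallback])
print(chat([HumanMessage(content="Say hi in French.")]).content)
```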


@@ -0,0 +1,84 @@
"""Test GPTRouter API wrapper."""
from typing import List
import pytest
from langchain_core.callbacks import (
CallbackManager,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_community.chat_models.gpt_router import ChatGPTRouter, GPTRouterModel
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_gpt_router_call() -> None:
"""Test valid call to GPTRouter."""
anthropic_claude = GPTRouterModel(
name="claude-instant-1.2", provider_name="anthropic"
)
chat = ChatGPTRouter(models_priority_list=[anthropic_claude])
message = HumanMessage(content="Hello World")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_gpt_router_call_incorrect_model() -> None:
"""Test invalid modelName"""
anthropic_claude = GPTRouterModel(
name="model_does_not_exist", provider_name="anthropic"
)
chat = ChatGPTRouter(models_priority_list=[anthropic_claude])
message = HumanMessage(content="Hello World")
with pytest.raises(Exception):
chat([message])
def test_gpt_router_generate() -> None:
"""Test generate method of GPTRouter."""
anthropic_claude = GPTRouterModel(
name="claude-instant-1.2", provider_name="anthropic"
)
chat = ChatGPTRouter(models_priority_list=[anthropic_claude])
chat_messages: List[List[BaseMessage]] = [
[HumanMessage(content="If (5 + x = 18), what is x?")]
]
messages_copy = [messages.copy() for messages in chat_messages]
result: LLMResult = chat.generate(chat_messages)
assert isinstance(result, LLMResult)
for response in result.generations[0]:
assert isinstance(response, ChatGeneration)
assert isinstance(response.text, str)
assert response.text == response.message.content
assert chat_messages == messages_copy
def test_gpt_router_streaming() -> None:
"""Test streaming tokens from GPTRouter."""
anthropic_claude = GPTRouterModel(
name="claude-instant-1.2", provider_name="anthropic"
)
chat = ChatGPTRouter(models_priority_list=[anthropic_claude], streaming=True)
message = HumanMessage(content="Hello")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_gpt_router_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
anthropic_claude = GPTRouterModel(
name="claude-instant-1.2", provider_name="anthropic"
)
chat = ChatGPTRouter(
models_priority_list=[anthropic_claude],
streaming=True,
callback_manager=callback_manager,
verbose=True,
)
message = HumanMessage(content="Write me a 5 line poem.")
chat([message])
assert callback_handler.llm_streams > 1


@@ -31,6 +31,7 @@ EXPECTED_ALL = [
"ChatHunyuan",
"GigaChat",
"VolcEngineMaasChat",
"ChatGPTRouter",
]