community[minor]: Integrating GPTRouter (#14900)

**Description:** Adds a LangChain integration for [GPTRouter](https://gpt-router.writesonic.com/) 🚀
**Tag maintainer:** @Gupta-Anubhav12 @samanyougarg @sirjan-ws-ext
**Twitter handle:** [@SamanyouGarg](https://twitter.com/SamanyouGarg)
Integration tests passing: <img width="1137" alt="Screenshot 2023-12-19 at 5 45 31 PM" src="https://github.com/Writesonic/langchain/assets/151817113/4a59df9a-ee30-47aa-9df9-b8c4eeb9dc76">
This commit is contained in:
parent 1069a93d18
commit 44cb899a93

231 docs/docs/integrations/chat/gpt_router.ipynb (new file)
@@ -0,0 +1,231 @@
{
 "cells": [
  {
   "cell_type": "raw",
   "id": "59148044",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_label: GPTRouter\n",
    "---"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bf733a38-db84-4363-89e2-de6735c37230",
   "metadata": {},
   "source": [
    "# ChatGPTRouter\n",
    "\n",
    "[GPTRouter](https://github.com/Writesonic/GPTRouter) is an open-source LLM API gateway that offers a universal API for 30+ LLMs, vision, and image models, with smart fallbacks based on uptime and latency, automatic retries, and streaming.\n",
    "\n",
    "This notebook covers how to get started with using LangChain + the GPTRouter I/O library.\n",
    "\n",
    "* Set the `GPT_ROUTER_API_KEY` environment variable\n",
    "* or use the `gpt_router_api_key` keyword argument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d0133ddd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: GPTRouter in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (0.1.3)\n",
      "Requirement already satisfied: pydantic==2.5.2 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from GPTRouter) (2.5.2)\n",
      "Requirement already satisfied: httpx>=0.25.2 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from GPTRouter) (0.25.2)\n",
      "Requirement already satisfied: annotated-types>=0.4.0 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from pydantic==2.5.2->GPTRouter) (0.6.0)\n",
      "Requirement already satisfied: pydantic-core==2.14.5 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from pydantic==2.5.2->GPTRouter) (2.14.5)\n",
      "Requirement already satisfied: typing-extensions>=4.6.1 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from pydantic==2.5.2->GPTRouter) (4.8.0)\n",
      "Requirement already satisfied: idna in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (3.6)\n",
      "Requirement already satisfied: anyio in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (3.7.1)\n",
      "Requirement already satisfied: sniffio in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (1.3.0)\n",
      "Requirement already satisfied: certifi in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (2023.11.17)\n",
      "Requirement already satisfied: httpcore==1.* in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (1.0.2)\n",
      "Requirement already satisfied: h11<0.15,>=0.13 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpcore==1.*->httpx>=0.25.2->GPTRouter) (0.14.0)\n",
      "Requirement already satisfied: exceptiongroup in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from anyio->httpx>=0.25.2->GPTRouter) (1.2.0)\n",
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install GPTRouter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.schema import HumanMessage\n",
    "from langchain_community.chat_models import ChatGPTRouter\n",
    "from langchain_community.chat_models.gpt_router import GPTRouterModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b8a9914b",
   "metadata": {},
   "outputs": [],
   "source": [
    "anthropic_claude = GPTRouterModel(name=\"claude-instant-1.2\", provider_name=\"anthropic\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chat = ChatGPTRouter(models_priority_list=[anthropic_claude])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\" J'aime programmer.\")"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages = [\n",
    "    HumanMessage(\n",
    "        content=\"Translate this sentence from English to French. I love programming.\"\n",
    "    )\n",
    "]\n",
    "chat(messages)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c",
   "metadata": {},
   "source": [
    "## `ChatGPTRouter` also supports async and streaming functionality:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "93a21c5c-6ef9-4688-be60-b2e1f94842fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.callbacks.manager import CallbackManager\n",
    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LLMResult(generations=[[ChatGeneration(text=\" J'aime programmer.\", generation_info={'finish_reason': 'stop_sequence'}, message=AIMessage(content=\" J'aime programmer.\"))]], llm_output={}, run=[RunInfo(run_id=UUID('9885f27f-c35a-4434-9f37-c254259762a5'))])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await chat.agenerate([messages])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " J'aime programmer."
     ]
    },
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\" J'aime programmer.\")"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat = ChatGPTRouter(\n",
    "    models_priority_list=[anthropic_claude],\n",
    "    streaming=True,\n",
    "    verbose=True,\n",
    "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
    ")\n",
    "chat(messages)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
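The markdown cell above mentions two ways to supply credentials, but the notebook only exercises the `GPT_ROUTER_API_KEY` environment variable. A minimal sketch of the second path, passing the key explicitly via the `gpt_router_api_key` keyword argument (the key value shown is a placeholder, not a real credential):

```python
from langchain_community.chat_models import ChatGPTRouter
from langchain_community.chat_models.gpt_router import GPTRouterModel

# Placeholder key for illustration only; load real credentials from a secret store.
chat = ChatGPTRouter(
    models_priority_list=[
        GPTRouterModel(name="claude-instant-1.2", provider_name="anthropic")
    ],
    gpt_router_api_key="<your-gpt-router-api-key>",
)
```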
libs/community/langchain_community/chat_models/__init__.py
@@ -31,6 +31,7 @@ from langchain_community.chat_models.fake import FakeListChatModel
 from langchain_community.chat_models.fireworks import ChatFireworks
 from langchain_community.chat_models.gigachat import GigaChat
 from langchain_community.chat_models.google_palm import ChatGooglePalm
+from langchain_community.chat_models.gpt_router import ChatGPTRouter
 from langchain_community.chat_models.human import HumanInputChatModel
 from langchain_community.chat_models.hunyuan import ChatHunyuan
 from langchain_community.chat_models.javelin_ai_gateway import ChatJavelinAIGateway
@@ -79,4 +80,5 @@ __all__ = [
     "ChatHunyuan",
     "GigaChat",
     "VolcEngineMaasChat",
+    "ChatGPTRouter",
 ]
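With this re-export in place, `ChatGPTRouter` resolves from the package root as well as from its defining module. A quick sanity-check sketch, assuming `langchain_community` is installed with this commit applied:

```python
from langchain_community.chat_models import ChatGPTRouter
from langchain_community.chat_models.gpt_router import ChatGPTRouter as DirectImport

# Both import paths point at the same class object.
assert ChatGPTRouter is DirectImport
```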
390 libs/community/langchain_community/chat_models/gpt_router.py (new file)
@@ -0,0 +1,390 @@
from __future__ import annotations

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_community.adapters.openai import (
    convert_dict_to_message,
    convert_message_to_dict,
)
from langchain_community.chat_models.openai import _convert_delta_to_message_chunk

if TYPE_CHECKING:
    from gpt_router.models import ChunkedGenerationResponse, GenerationResponse


logger = logging.getLogger(__name__)

DEFAULT_API_BASE_URL = "https://gpt-router-preview.writesonic.com"


class GPTRouterException(Exception):
    """Error raised by the GPTRouter APIs."""


class GPTRouterModel(BaseModel):
    """A single GPTRouter model entry: the model name and its provider."""

    name: str
    provider_name: str


def get_ordered_generation_requests(
    models_priority_list: List[GPTRouterModel], **kwargs
):
    """
    Return the body for the model router input.
    """
    from gpt_router.models import GenerationParams, ModelGenerationRequest

    return [
        ModelGenerationRequest(
            model_name=model.name,
            provider_name=model.provider_name,
            order=index + 1,
            prompt_params=GenerationParams(**kwargs),
        )
        for index, model in enumerate(models_priority_list)
    ]


def _create_retry_decorator(
    llm: ChatGPTRouter,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    from gpt_router import exceptions

    errors = [
        exceptions.GPTRouterApiTimeoutError,
        exceptions.GPTRouterInternalServerError,
        exceptions.GPTRouterNotAvailableError,
        exceptions.GPTRouterTooManyRequestsError,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )


def completion_with_retry(
    llm: ChatGPTRouter,
    models_priority_list: List[GPTRouterModel],
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        ordered_generation_requests = get_ordered_generation_requests(
            models_priority_list, **kwargs
        )
        return llm.client.generate(
            ordered_generation_requests=ordered_generation_requests,
            is_stream=kwargs.get("stream", False),
        )

    return _completion_with_retry(**kwargs)


async def acompletion_with_retry(
    llm: ChatGPTRouter,
    models_priority_list: List[GPTRouterModel],
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        ordered_generation_requests = get_ordered_generation_requests(
            models_priority_list, **kwargs
        )
        return await llm.client.agenerate(
            ordered_generation_requests=ordered_generation_requests,
            is_stream=kwargs.get("stream", False),
        )

    return await _completion_with_retry(**kwargs)


class ChatGPTRouter(BaseChatModel):
    """GPTRouter by Writesonic Inc.

    For more information, see https://gpt-router.writesonic.com/docs
    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    models_priority_list: List[GPTRouterModel] = Field(min_items=1)
    gpt_router_api_base: Optional[str] = Field(default=None)
    """WriteSonic GPTRouter custom endpoint"""
    gpt_router_api_key: Optional[str] = None
    """WriteSonic GPTRouter API Key"""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for the `create` call not explicitly specified."""
    max_retries: int = 4
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: int = 256
    """Maximum number of tokens to generate."""

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        values["gpt_router_api_base"] = get_from_dict_or_env(
            values,
            "gpt_router_api_base",
            "GPT_ROUTER_API_BASE",
            DEFAULT_API_BASE_URL,
        )

        values["gpt_router_api_key"] = get_from_dict_or_env(
            values,
            "gpt_router_api_key",
            "GPT_ROUTER_API_KEY",
        )

        try:
            from gpt_router.client import GPTRouterClient
        except ImportError:
            raise GPTRouterException(
                "Could not import the GPTRouter python package. "
                "Please install it with `pip install GPTRouter`."
            )

        gpt_router_client = GPTRouterClient(
            values["gpt_router_api_base"], values["gpt_router_api_key"]
        )
        values["client"] = gpt_router_client

        return values

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {
            "gpt_router_api_key": "GPT_ROUTER_API_KEY",
        }

    @property
    def lc_serializable(self) -> bool:
        return True

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "gpt-router-chat"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"models_priority_list": self.models_priority_list},
            **self._default_params,
        }

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the GPTRouter API."""
        return {
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            **self.model_kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        stream: bool | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": False}
        response = completion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        stream: bool | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": False}
        response = await acompletion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        return self._create_chat_result(response)

    def _create_chat_generation_chunk(
        self, data: Mapping[str, Any], default_chunk_class
    ):
        chunk = _convert_delta_to_message_chunk(
            {"content": data.get("text", "")}, default_chunk_class
        )
        finish_reason = data.get("finish_reason")
        generation_info = (
            dict(finish_reason=finish_reason) if finish_reason is not None else None
        )
        default_chunk_class = chunk.__class__
        chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
        return chunk, default_chunk_class

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        generator_response = completion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        for chunk in generator_response:
            if chunk.event != "update":
                continue

            chunk, default_chunk_class = self._create_chat_generation_chunk(
                chunk.data, default_chunk_class
            )

            yield chunk

            if run_manager:
                run_manager.on_llm_new_token(
                    token=chunk.message.content, chunk=chunk.message
                )

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        generator_response = acompletion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        async for chunk in await generator_response:
            if chunk.event != "update":
                continue

            chunk, default_chunk_class = self._create_chat_generation_chunk(
                chunk.data, default_chunk_class
            )

            yield chunk

            if run_manager:
                await run_manager.on_llm_new_token(
                    token=chunk.message.content, chunk=chunk.message
                )

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._default_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: GenerationResponse) -> ChatResult:
        generations = []
        for res in response.choices:
            message = convert_dict_to_message(
                {
                    "role": "assistant",
                    "content": res.text,
                }
            )
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.finish_reason),
            )
            generations.append(gen)
        llm_output = {"token_usage": response.meta, "model": response.model}
        return ChatResult(generations=generations, llm_output=llm_output)
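In `get_ordered_generation_requests`, the fallback priority is encoded by position: the index in `models_priority_list` becomes the 1-based `order` on each `ModelGenerationRequest`, so the router tries earlier entries first. A minimal sketch of that mapping (assumes the `GPTRouter` package is installed; the second model here is a hypothetical fallback, not taken from this commit):

```python
from langchain_community.chat_models.gpt_router import (
    GPTRouterModel,
    get_ordered_generation_requests,
)

models = [
    GPTRouterModel(name="claude-instant-1.2", provider_name="anthropic"),
    GPTRouterModel(name="gpt-3.5-turbo", provider_name="openai"),  # hypothetical fallback
]
requests = get_ordered_generation_requests(
    models,
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=64,
)
# Position in the priority list becomes the 1-based order field.
assert [r.order for r in requests] == [1, 2]
```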
@@ -0,0 +1,84 @@
"""Test GPTRouter API wrapper."""
from typing import List

import pytest
from langchain_core.callbacks import (
    CallbackManager,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult

from langchain_community.chat_models.gpt_router import ChatGPTRouter, GPTRouterModel
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_gpt_router_call() -> None:
    """Test valid call to GPTRouter."""
    anthropic_claude = GPTRouterModel(
        name="claude-instant-1.2", provider_name="anthropic"
    )
    chat = ChatGPTRouter(models_priority_list=[anthropic_claude])
    message = HumanMessage(content="Hello World")
    response = chat([message])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)


def test_gpt_router_call_incorrect_model() -> None:
    """Test that an invalid model name raises an error."""
    anthropic_claude = GPTRouterModel(
        name="model_does_not_exist", provider_name="anthropic"
    )
    chat = ChatGPTRouter(models_priority_list=[anthropic_claude])
    message = HumanMessage(content="Hello World")
    with pytest.raises(Exception):
        chat([message])


def test_gpt_router_generate() -> None:
    """Test generate method of GPTRouter."""
    anthropic_claude = GPTRouterModel(
        name="claude-instant-1.2", provider_name="anthropic"
    )
    chat = ChatGPTRouter(models_priority_list=[anthropic_claude])
    chat_messages: List[List[BaseMessage]] = [
        [HumanMessage(content="If (5 + x = 18), what is x?")]
    ]
    messages_copy = [messages.copy() for messages in chat_messages]
    result: LLMResult = chat.generate(chat_messages)
    assert isinstance(result, LLMResult)
    for response in result.generations[0]:
        assert isinstance(response, ChatGeneration)
        assert isinstance(response.text, str)
        assert response.text == response.message.content
    assert chat_messages == messages_copy


def test_gpt_router_streaming() -> None:
    """Test streaming tokens from GPTRouter."""
    anthropic_claude = GPTRouterModel(
        name="claude-instant-1.2", provider_name="anthropic"
    )
    chat = ChatGPTRouter(models_priority_list=[anthropic_claude], streaming=True)
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)


def test_gpt_router_streaming_callback() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])
    anthropic_claude = GPTRouterModel(
        name="claude-instant-1.2", provider_name="anthropic"
    )
    chat = ChatGPTRouter(
        models_priority_list=[anthropic_claude],
        streaming=True,
        callback_manager=callback_manager,
        verbose=True,
    )
    message = HumanMessage(content="Write me a 5 line poem.")
    chat([message])
    assert callback_handler.llm_streams > 1
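The suite above exercises sync calls and streaming only; since the class also implements `_agenerate`, an async test would follow the same shape. A sketch, not part of this commit, assuming `pytest-asyncio` or an equivalent async test runner is configured:

```python
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.outputs import LLMResult

from langchain_community.chat_models.gpt_router import ChatGPTRouter, GPTRouterModel


@pytest.mark.asyncio
async def test_gpt_router_agenerate() -> None:
    """Sketch: async generation mirrors the sync generate path."""
    anthropic_claude = GPTRouterModel(
        name="claude-instant-1.2", provider_name="anthropic"
    )
    chat = ChatGPTRouter(models_priority_list=[anthropic_claude])
    result: LLMResult = await chat.agenerate([[HumanMessage(content="Hello")]])
    assert isinstance(result, LLMResult)
    assert len(result.generations) == 1
```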
@@ -31,6 +31,7 @@ EXPECTED_ALL = [
     "ChatHunyuan",
     "GigaChat",
     "VolcEngineMaasChat",
+    "ChatGPTRouter",
 ]