community[minor]: Integrating GPTRouter (#14900)
**Description:** Adds a LangChain integration for [GPTRouter](https://gpt-router.writesonic.com/) 🚀
**Tag maintainer:** @Gupta-Anubhav12 @samanyougarg @sirjan-ws-ext
**Twitter handle:** [@SamanyouGarg](https://twitter.com/SamanyouGarg)
Integration tests passing (screenshot attached in the original PR).
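For reviewers, a minimal sketch of the API this PR adds, assuming the `GPTRouter` package is installed and `GPT_ROUTER_API_KEY` is set; the model/provider pair is the one used in the docs notebook below:

```python
from langchain.schema import HumanMessage
from langchain_community.chat_models import ChatGPTRouter
from langchain_community.chat_models.gpt_router import GPTRouterModel

# Declare which model/provider pair GPTRouter should route to.
anthropic_claude = GPTRouterModel(name="claude-instant-1.2", provider_name="anthropic")
chat = ChatGPTRouter(models_priority_list=[anthropic_claude])

# Returns an AIMessage with the model's reply.
print(chat([HumanMessage(content="Say hello in French.")]))
```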
This commit is contained in:
parent: 1069a93d18
commit: 44cb899a93
docs/docs/integrations/chat/gpt_router.ipynb (new file, 231 lines)
@@ -0,0 +1,231 @@
{
 "cells": [
  {
   "cell_type": "raw",
   "id": "59148044",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_label: GPTRouter\n",
    "---"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bf733a38-db84-4363-89e2-de6735c37230",
   "metadata": {},
   "source": [
    "# ChatGPTRouter\n",
    "\n",
    "[GPTRouter](https://github.com/Writesonic/GPTRouter) is an open-source LLM API gateway that offers a universal API for 30+ LLMs, vision, and image models, with smart fallbacks based on uptime and latency, automatic retries, and streaming.\n",
    "\n",
    "This notebook covers how to get started with LangChain + the GPTRouter I/O library.\n",
    "\n",
    "* Set the `GPT_ROUTER_API_KEY` environment variable,\n",
    "* or use the `gpt_router_api_key` keyword argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d0133ddd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: GPTRouter in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (0.1.3)\n",
      "Requirement already satisfied: pydantic==2.5.2 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from GPTRouter) (2.5.2)\n",
      "Requirement already satisfied: httpx>=0.25.2 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from GPTRouter) (0.25.2)\n",
      "Requirement already satisfied: annotated-types>=0.4.0 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from pydantic==2.5.2->GPTRouter) (0.6.0)\n",
      "Requirement already satisfied: pydantic-core==2.14.5 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from pydantic==2.5.2->GPTRouter) (2.14.5)\n",
      "Requirement already satisfied: typing-extensions>=4.6.1 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from pydantic==2.5.2->GPTRouter) (4.8.0)\n",
      "Requirement already satisfied: idna in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (3.6)\n",
      "Requirement already satisfied: anyio in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (3.7.1)\n",
      "Requirement already satisfied: sniffio in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (1.3.0)\n",
      "Requirement already satisfied: certifi in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (2023.11.17)\n",
      "Requirement already satisfied: httpcore==1.* in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpx>=0.25.2->GPTRouter) (1.0.2)\n",
      "Requirement already satisfied: h11<0.15,>=0.13 in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from httpcore==1.*->httpx>=0.25.2->GPTRouter) (0.14.0)\n",
      "Requirement already satisfied: exceptiongroup in /Users/sirjan-ws/.pyenv/versions/3.10.13/envs/langchain_venv5/lib/python3.10/site-packages (from anyio->httpx>=0.25.2->GPTRouter) (1.2.0)\n",
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install GPTRouter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.schema import HumanMessage\n",
    "from langchain_community.chat_models import ChatGPTRouter\n",
    "from langchain_community.chat_models.gpt_router import GPTRouterModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b8a9914b",
   "metadata": {},
   "outputs": [],
   "source": [
    "anthropic_claude = GPTRouterModel(name=\"claude-instant-1.2\", provider_name=\"anthropic\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chat = ChatGPTRouter(models_priority_list=[anthropic_claude])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\" J'aime programmer.\")"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages = [\n",
    "    HumanMessage(\n",
    "        content=\"Translate this sentence from English to French. I love programming.\"\n",
    "    )\n",
    "]\n",
    "chat(messages)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c",
   "metadata": {},
   "source": [
    "## `ChatGPTRouter` also supports async and streaming functionality:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "93a21c5c-6ef9-4688-be60-b2e1f94842fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.callbacks.manager import CallbackManager\n",
    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LLMResult(generations=[[ChatGeneration(text=\" J'aime programmer.\", generation_info={'finish_reason': 'stop_sequence'}, message=AIMessage(content=\" J'aime programmer.\"))]], llm_output={}, run=[RunInfo(run_id=UUID('9885f27f-c35a-4434-9f37-c254259762a5'))])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await chat.agenerate([messages])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " J'aime programmer."
     ]
    },
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\" J'aime programmer.\")"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat = ChatGPTRouter(\n",
    "    models_priority_list=[anthropic_claude],\n",
    "    streaming=True,\n",
    "    verbose=True,\n",
    "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
    ")\n",
    "chat(messages)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
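The notebook demonstrates a single-model priority list; since priority-based fallback is GPTRouter's headline feature, a hedged sketch with two entries is worth showing. The position in `models_priority_list` becomes the request's `order` (1 = tried first, per `get_ordered_generation_requests` below); the `gpt-3.5-turbo`/`openai` pair is an illustrative assumption, not taken from this PR:

```python
from langchain_community.chat_models import ChatGPTRouter
from langchain_community.chat_models.gpt_router import GPTRouterModel

# Tried first; if Anthropic is down or slow, GPTRouter falls back to the next entry.
primary = GPTRouterModel(name="claude-instant-1.2", provider_name="anthropic")
# Illustrative fallback; any model/provider pair exposed by GPTRouter works here.
fallback = GPTRouterModel(name="gpt-3.5-turbo", provider_name="openai")

chat = ChatGPTRouter(
    models_priority_list=[primary, fallback],
    max_retries=4,  # transient API errors are retried via tenacity
)
```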
libs/community/langchain_community/chat_models/__init__.py
@@ -31,6 +31,7 @@ from langchain_community.chat_models.fake import FakeListChatModel
 from langchain_community.chat_models.fireworks import ChatFireworks
 from langchain_community.chat_models.gigachat import GigaChat
 from langchain_community.chat_models.google_palm import ChatGooglePalm
+from langchain_community.chat_models.gpt_router import ChatGPTRouter
 from langchain_community.chat_models.human import HumanInputChatModel
 from langchain_community.chat_models.hunyuan import ChatHunyuan
 from langchain_community.chat_models.javelin_ai_gateway import ChatJavelinAIGateway
@@ -79,4 +80,5 @@ __all__ = [
     "ChatHunyuan",
     "GigaChat",
     "VolcEngineMaasChat",
+    "ChatGPTRouter",
 ]
libs/community/langchain_community/chat_models/gpt_router.py (new file, 390 lines)
@@ -0,0 +1,390 @@
from __future__ import annotations

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_community.adapters.openai import (
    convert_dict_to_message,
    convert_message_to_dict,
)
from langchain_community.chat_models.openai import _convert_delta_to_message_chunk

if TYPE_CHECKING:
    from gpt_router.models import ChunkedGenerationResponse, GenerationResponse


logger = logging.getLogger(__name__)

DEFAULT_API_BASE_URL = "https://gpt-router-preview.writesonic.com"


class GPTRouterException(Exception):
    """Error raised by the GPTRouter APIs."""


class GPTRouterModel(BaseModel):
    """A single model choice: the model name plus the provider serving it."""

    name: str
    provider_name: str


def get_ordered_generation_requests(
    models_priority_list: List[GPTRouterModel], **kwargs: Any
) -> List[Any]:
    """
    Return the body for the model router input.

    The position of each model in ``models_priority_list`` becomes its
    fallback order (1 = tried first).
    """
    from gpt_router.models import GenerationParams, ModelGenerationRequest

    return [
        ModelGenerationRequest(
            model_name=model.name,
            provider_name=model.provider_name,
            order=index + 1,
            prompt_params=GenerationParams(**kwargs),
        )
        for index, model in enumerate(models_priority_list)
    ]


def _create_retry_decorator(
    llm: ChatGPTRouter,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    from gpt_router import exceptions

    # Retry only on transient API errors; anything else propagates immediately.
    errors = [
        exceptions.GPTRouterApiTimeoutError,
        exceptions.GPTRouterInternalServerError,
        exceptions.GPTRouterNotAvailableError,
        exceptions.GPTRouterTooManyRequestsError,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )


def completion_with_retry(
    llm: ChatGPTRouter,
    models_priority_list: List[GPTRouterModel],
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        ordered_generation_requests = get_ordered_generation_requests(
            models_priority_list, **kwargs
        )
        return llm.client.generate(
            ordered_generation_requests=ordered_generation_requests,
            is_stream=kwargs.get("stream", False),
        )

    return _completion_with_retry(**kwargs)


async def acompletion_with_retry(
    llm: ChatGPTRouter,
    models_priority_list: List[GPTRouterModel],
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        ordered_generation_requests = get_ordered_generation_requests(
            models_priority_list, **kwargs
        )
        return await llm.client.agenerate(
            ordered_generation_requests=ordered_generation_requests,
            is_stream=kwargs.get("stream", False),
        )

    return await _completion_with_retry(**kwargs)


class ChatGPTRouter(BaseChatModel):
    """GPTRouter by Writesonic Inc.

    For more information, see https://gpt-router.writesonic.com/docs
    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    models_priority_list: List[GPTRouterModel] = Field(min_items=1)
    gpt_router_api_base: Optional[str] = Field(default=None)
    """WriteSonic GPTRouter custom endpoint"""
    gpt_router_api_key: Optional[str] = None
    """WriteSonic GPTRouter API key"""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for the `create` call that are not
    explicitly specified."""
    max_retries: int = 4
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: int = 256
    """Maximum number of tokens to generate."""

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        values["gpt_router_api_base"] = get_from_dict_or_env(
            values,
            "gpt_router_api_base",
            "GPT_ROUTER_API_BASE",
            DEFAULT_API_BASE_URL,
        )

        values["gpt_router_api_key"] = get_from_dict_or_env(
            values,
            "gpt_router_api_key",
            "GPT_ROUTER_API_KEY",
        )

        try:
            from gpt_router.client import GPTRouterClient
        except ImportError:
            raise GPTRouterException(
                "Could not import the GPTRouter python package. "
                "Please install it with `pip install GPTRouter`."
            )

        gpt_router_client = GPTRouterClient(
            values["gpt_router_api_base"], values["gpt_router_api_key"]
        )
        values["client"] = gpt_router_client

        return values

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {
            "gpt_router_api_key": "GPT_ROUTER_API_KEY",
        }

    @property
    def lc_serializable(self) -> bool:
        return True

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "gpt-router-chat"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"models_priority_list": self.models_priority_list},
            **self._default_params,
        }

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the GPTRouter API."""
        return {
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            **self.model_kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": False}
        response = completion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": False}
        response = await acompletion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        return self._create_chat_result(response)

    def _create_chat_generation_chunk(
        self, data: Mapping[str, Any], default_chunk_class: type
    ) -> Tuple[ChatGenerationChunk, type]:
        chunk = _convert_delta_to_message_chunk(
            {"content": data.get("text", "")}, default_chunk_class
        )
        finish_reason = data.get("finish_reason")
        generation_info = (
            dict(finish_reason=finish_reason) if finish_reason is not None else None
        )
        default_chunk_class = chunk.__class__
        chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
        return chunk, default_chunk_class

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        generator_response = completion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        for chunk in generator_response:
            # Only "update" events carry generated text; skip everything else.
            if chunk.event != "update":
                continue

            chunk, default_chunk_class = self._create_chat_generation_chunk(
                chunk.data, default_chunk_class
            )

            yield chunk

            if run_manager:
                run_manager.on_llm_new_token(
                    token=chunk.message.content, chunk=chunk.message
                )

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        generator_response = acompletion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        async for chunk in await generator_response:
            # Only "update" events carry generated text; skip everything else.
            if chunk.event != "update":
                continue

            chunk, default_chunk_class = self._create_chat_generation_chunk(
                chunk.data, default_chunk_class
            )

            yield chunk

            if run_manager:
                await run_manager.on_llm_new_token(
                    token=chunk.message.content, chunk=chunk.message
                )

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._default_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: GenerationResponse) -> ChatResult:
        generations = []
        for res in response.choices:
            message = convert_dict_to_message(
                {
                    "role": "assistant",
                    "content": res.text,
                }
            )
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.finish_reason),
            )
            generations.append(gen)
        llm_output = {"token_usage": response.meta, "model": response.model}
        return ChatResult(generations=generations, llm_output=llm_output)
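As a usage note for the `_astream` path above, here is a hedged sketch that drives it through the public `astream` iterator inherited from `BaseChatModel`; the async entry point and the prompt are illustrative, not part of this PR:

```python
import asyncio

from langchain.schema import HumanMessage
from langchain_community.chat_models import ChatGPTRouter
from langchain_community.chat_models.gpt_router import GPTRouterModel


async def main() -> None:
    chat = ChatGPTRouter(
        models_priority_list=[
            GPTRouterModel(name="claude-instant-1.2", provider_name="anthropic")
        ]
    )
    # astream() yields AIMessageChunk objects as "update" events arrive.
    async for chunk in chat.astream([HumanMessage(content="Tell me a joke.")]):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```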
libs/community/tests/integration_tests/chat_models/test_gpt_router.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""Test GPTRouter API wrapper."""
from typing import List

import pytest
from langchain_core.callbacks import (
    CallbackManager,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult

from langchain_community.chat_models.gpt_router import ChatGPTRouter, GPTRouterModel
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_gpt_router_call() -> None:
    """Test a valid call to GPTRouter."""
    anthropic_claude = GPTRouterModel(
        name="claude-instant-1.2", provider_name="anthropic"
    )
    chat = ChatGPTRouter(models_priority_list=[anthropic_claude])
    message = HumanMessage(content="Hello World")
    response = chat([message])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)


def test_gpt_router_call_incorrect_model() -> None:
    """Test that an invalid model name raises."""
    anthropic_claude = GPTRouterModel(
        name="model_does_not_exist", provider_name="anthropic"
    )
    chat = ChatGPTRouter(models_priority_list=[anthropic_claude])
    message = HumanMessage(content="Hello World")
    with pytest.raises(Exception):
        chat([message])


def test_gpt_router_generate() -> None:
    """Test the generate method of GPTRouter."""
    anthropic_claude = GPTRouterModel(
        name="claude-instant-1.2", provider_name="anthropic"
    )
    chat = ChatGPTRouter(models_priority_list=[anthropic_claude])
    chat_messages: List[List[BaseMessage]] = [
        [HumanMessage(content="If (5 + x = 18), what is x?")]
    ]
    messages_copy = [messages.copy() for messages in chat_messages]
    result: LLMResult = chat.generate(chat_messages)
    assert isinstance(result, LLMResult)
    for response in result.generations[0]:
        assert isinstance(response, ChatGeneration)
        assert isinstance(response.text, str)
        assert response.text == response.message.content
    # generate() must not mutate its input messages.
    assert chat_messages == messages_copy


def test_gpt_router_streaming() -> None:
    """Test streaming tokens from GPTRouter."""
    anthropic_claude = GPTRouterModel(
        name="claude-instant-1.2", provider_name="anthropic"
    )
    chat = ChatGPTRouter(models_priority_list=[anthropic_claude], streaming=True)
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)


def test_gpt_router_streaming_callback() -> None:
    """Test that streaming correctly invokes the on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])
    anthropic_claude = GPTRouterModel(
        name="claude-instant-1.2", provider_name="anthropic"
    )
    chat = ChatGPTRouter(
        models_priority_list=[anthropic_claude],
        streaming=True,
        callback_manager=callback_manager,
        verbose=True,
    )
    message = HumanMessage(content="Write me a 5 line poem.")
    chat([message])
    assert callback_handler.llm_streams > 1
libs/community/tests/unit_tests/chat_models/test_imports.py
@@ -31,6 +31,7 @@ EXPECTED_ALL = [
     "ChatHunyuan",
     "GigaChat",
     "VolcEngineMaasChat",
+    "ChatGPTRouter",
 ]