mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-06 03:20:41 +00:00
feat(model): Support claude proxy models (#2155)
This commit is contained in:
@@ -5,17 +5,11 @@ import logging
|
||||
from concurrent.futures import Executor
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core import (
|
||||
MessageConverter,
|
||||
ModelMetadata,
|
||||
ModelOutput,
|
||||
ModelRequest,
|
||||
ModelRequestContext,
|
||||
)
|
||||
from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.model.parameter import ProxyModelParameters
|
||||
from dbgpt.model.proxy.base import ProxyLLMClient
|
||||
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
|
||||
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
|
||||
from dbgpt.model.utils.chatgpt_utils import OpenAIParameters
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
@@ -32,15 +26,7 @@ async def chatgpt_generate_stream(
|
||||
model: ProxyModel, tokenizer, params, device, context_len=2048
|
||||
):
|
||||
client: OpenAILLMClient = model.proxy_llm_client
|
||||
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
|
||||
request = ModelRequest.build_request(
|
||||
client.default_model,
|
||||
messages=params["messages"],
|
||||
temperature=params.get("temperature"),
|
||||
context=context,
|
||||
max_new_tokens=params.get("max_new_tokens"),
|
||||
stop=params.get("stop"),
|
||||
)
|
||||
request = parse_model_request(params, client.default_model, stream=True)
|
||||
async for r in client.generate_stream(request):
|
||||
yield r
|
||||
|
||||
@@ -191,6 +177,8 @@ class OpenAILLMClient(ProxyLLMClient):
|
||||
payload["max_tokens"] = request.max_new_tokens
|
||||
if request.stop:
|
||||
payload["stop"] = request.stop
|
||||
if request.top_p:
|
||||
payload["top_p"] = request.top_p
|
||||
return payload
|
||||
|
||||
async def generate(
|
||||
|
Reference in New Issue
Block a user