refactor: Refactor proxy LLM (#1064)

Fangyin Cheng
2024-01-14 21:01:37 +08:00
committed by GitHub
parent a035433170
commit 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions


@@ -2,6 +2,7 @@ import asyncio
 from typing import AsyncIterator, List, Optional

 from dbgpt.core.interface.llm import (
+    DefaultMessageConverter,
     LLMClient,
     MessageConverter,
     ModelMetadata,
@@ -13,14 +14,28 @@ from dbgpt.model.parameter import WorkerType

 class DefaultLLMClient(LLMClient):
-    def __init__(self, worker_manager: WorkerManager):
+    """Default LLM client implementation.
+
+    Connect to the worker manager and send the request to the worker manager.
+
+    Args:
+        worker_manager (WorkerManager): worker manager instance.
+        auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to False.
+    """
+
+    def __init__(
+        self, worker_manager: WorkerManager, auto_convert_message: bool = False
+    ):
         self._worker_manager = worker_manager
+        self._auto_covert_message = auto_convert_message

     async def generate(
         self,
         request: ModelRequest,
         message_converter: Optional[MessageConverter] = None,
     ) -> ModelOutput:
+        if not message_converter and self._auto_covert_message:
+            message_converter = DefaultMessageConverter()
         request = await self.covert_message(request, message_converter)
         return await self._worker_manager.generate(request.to_dict())
@@ -29,6 +44,8 @@ class DefaultLLMClient(LLMClient):
         request: ModelRequest,
         message_converter: Optional[MessageConverter] = None,
     ) -> AsyncIterator[ModelOutput]:
+        if not message_converter and self._auto_covert_message:
+            message_converter = DefaultMessageConverter()
         request = await self.covert_message(request, message_converter)
         async for output in self._worker_manager.generate_stream(request.to_dict()):
             yield output
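
For context, a minimal usage sketch of the new auto_convert_message flag (not part of this commit; the model name and the message construction are illustrative assumptions, and DefaultLLMClient refers to the class shown in the diff above):

# Usage sketch, not from this PR: with auto_convert_message=True, a
# DefaultMessageConverter is applied to any request that does not carry
# an explicit MessageConverter.
from dbgpt.core import ModelMessage, ModelMessageRoleType
from dbgpt.core.interface.llm import ModelRequest


async def ask(worker_manager) -> None:
    # worker_manager is assumed to be a running WorkerManager instance.
    client = DefaultLLMClient(worker_manager, auto_convert_message=True)
    request = ModelRequest(
        model="vicuna-13b-v1.5",  # illustrative model name
        messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi!")],
    )
    output = await client.generate(request)  # one-shot completion
    print(output.text)
    # Streaming variant; the same auto-conversion rule applies per request.
    async for chunk in client.generate_stream(request):
        print(chunk.text)

Callers that already pass their own MessageConverter are unaffected, since the default converter is only substituted when none is given.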