mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-15 14:11:14 +00:00)
refactor: Refactor proxy LLM (#1064)

dbgpt/storage/cache/operator.py (new file, 247 lines added)
@@ -0,0 +1,247 @@
import logging
from typing import AsyncIterator, Dict, List, Optional, Union

from dbgpt.core import ModelOutput, ModelRequest
from dbgpt.core.awel import (
    BaseOperator,
    BranchFunc,
    BranchOperator,
    MapOperator,
    StreamifyAbsOperator,
    TransformStreamAbsOperator,
)
from dbgpt.storage.cache import CacheManager, LLMCacheClient, LLMCacheKey, LLMCacheValue

logger = logging.getLogger(__name__)

_LLM_MODEL_INPUT_VALUE_KEY = "llm_model_input_value"
_LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache"


class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput]):
    """Operator for streaming processing of model outputs with caching.

    Args:
        cache_manager (CacheManager): The cache manager to handle caching operations.
        **kwargs: Additional keyword arguments.

    Methods:
        streamify: Processes a stream of inputs with cache support, yielding model
            outputs.
    """

    def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
        super().__init__(**kwargs)
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)

    async def streamify(self, input_value: ModelRequest) -> AsyncIterator[ModelOutput]:
        """Process the input as a stream with cache support and yield model outputs.

        Args:
            input_value (ModelRequest): The input value for the model.

        Returns:
            AsyncIterator[ModelOutput]: An asynchronous iterator of cached model
                outputs.
        """
        cache_dict = _parse_cache_key_dict(input_value)
        llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
        llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
        logger.info(f"llm_cache_value: {llm_cache_value}")
        for out in llm_cache_value.get_value().output:
            yield out


class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
    """Operator for map-based processing of model outputs with caching.

    Args:
        cache_manager (CacheManager): Manager for caching operations.
        **kwargs: Additional keyword arguments.

    Methods:
        map: Processes a single input with cache support and returns the model output.
    """

    def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
        super().__init__(**kwargs)
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)

    async def map(self, input_value: ModelRequest) -> ModelOutput:
        """Process a single input with cache support and return the model output.

        Args:
            input_value (ModelRequest): The input value for the model.

        Returns:
            ModelOutput: The cached output for this request.
        """
        cache_dict = _parse_cache_key_dict(input_value)
        llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
        llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
        logger.info(f"llm_cache_value: {llm_cache_value}")
        return llm_cache_value.get_value().output


class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
    """A branch operator that decides whether to serve a request from the cache or to
    process it with the model.

    Args:
        cache_manager (CacheManager): The cache manager for managing cache operations.
        model_task_name (str): The name of the task that processes data with the model.
        cache_task_name (str): The name of the task that reads data from the cache.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        cache_manager: CacheManager,
        model_task_name: str,
        cache_task_name: str,
        **kwargs,
    ):
        super().__init__(branches=None, **kwargs)
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)
        self._model_task_name = model_task_name
        self._cache_task_name = cache_task_name

    async def branches(
        self,
    ) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
        """Define the branch logic based on cache availability.

        Returns:
            Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]: A dictionary
                mapping branch check functions to downstream task names.
        """

        async def check_cache_true(input_value: ModelRequest) -> bool:
            # Check whether the cache contains a result for the given input.
            if input_value.context and not input_value.context.cache_enable:
                return False
            cache_dict = _parse_cache_key_dict(input_value)
            cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
            cache_value = await self._client.get(cache_key)
            logger.debug(
                f"cache_key: {cache_key}, hash key: {hash(cache_key)}, "
                f"cache_value: {cache_value}"
            )
            # Share the cache key so the save operators can reuse it later.
            await self.current_dag_context.save_to_share_data(
                _LLM_MODEL_INPUT_VALUE_KEY, cache_key, overwrite=True
            )
            return bool(cache_value)

        async def check_cache_false(input_value: ModelRequest) -> bool:
            # Inverse of check_cache_true.
            return not await check_cache_true(input_value)

        return {
            check_cache_true: self._cache_task_name,
            check_cache_false: self._model_task_name,
        }


class ModelStreamSaveCacheOperator(
    TransformStreamAbsOperator[ModelOutput, ModelOutput]
):
    """An operator that saves a stream of model outputs to the cache.

    Args:
        cache_manager (CacheManager): The cache manager for handling cache operations.
        **kwargs: Additional keyword arguments.
    """

    def __init__(self, cache_manager: CacheManager, **kwargs):
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)
        super().__init__(**kwargs)

    async def transform_stream(
        self, input_value: AsyncIterator[ModelOutput]
    ) -> AsyncIterator[ModelOutput]:
        """Transform the input stream by saving the outputs to the cache.

        Args:
            input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of
                model outputs.

        Returns:
            AsyncIterator[ModelOutput]: The same outputs, re-yielded while being
                collected for caching.
        """
        llm_cache_key: Optional[LLMCacheKey] = None
        outputs = []
        async for out in input_value:
            if not llm_cache_key:
                llm_cache_key = await self.current_dag_context.get_from_share_data(
                    _LLM_MODEL_INPUT_VALUE_KEY
                )
            outputs.append(out)
            yield out
        if llm_cache_key and _is_success_model_output(outputs):
            llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
            await self._client.set(llm_cache_key, llm_cache_value)


class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
    """An operator that saves a single model output to the cache.

    Args:
        cache_manager (CacheManager): The cache manager for handling cache operations.
        **kwargs: Additional keyword arguments.
    """

    def __init__(self, cache_manager: CacheManager, **kwargs):
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)
        super().__init__(**kwargs)

    async def map(self, input_value: ModelOutput) -> ModelOutput:
        """Save a single model output to the cache and return it.

        Args:
            input_value (ModelOutput): The output from the model to be cached.

        Returns:
            ModelOutput: The same model output.
        """
        llm_cache_key: LLMCacheKey = await self.current_dag_context.get_from_share_data(
            _LLM_MODEL_INPUT_VALUE_KEY
        )
        llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
        if llm_cache_key and _is_success_model_output(input_value):
            await self._client.set(llm_cache_key, llm_cache_value)
        return input_value


def _parse_cache_key_dict(input_value: ModelRequest) -> Dict:
    """Parse and extract the relevant fields from the request to form a cache key
    dictionary.

    Args:
        input_value (ModelRequest): The model request containing the prompt and model
            parameters.

    Returns:
        Dict: A dictionary used for generating cache keys.
    """
    prompt: str = input_value.messages_to_string().strip()
    return {
        "prompt": prompt,
        "model_name": input_value.model,
        "temperature": input_value.temperature,
        "max_new_tokens": input_value.max_new_tokens,
        # "top_p": input_value.get("top_p", "1.0"),
        # TODO pass model_type
        # "model_type": input_value.get("model_type", "huggingface"),
    }


def _is_success_model_output(out: Union[Dict, ModelOutput, List[ModelOutput]]) -> bool:
    if not out:
        return False
    if isinstance(out, list):
        # Check the last model output.
        out = out[-1]
    error_code = 0
    if isinstance(out, ModelOutput):
        error_code = out.error_code
    else:
        error_code = int(out.get("error_code", 0))
    return error_code == 0
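
Taken together, these operators are meant to be wired into an AWEL DAG around an LLM task: ModelCacheBranchOperator checks the cache and routes each request either to a cache-read task or to the model task, while the save operators write successful outputs back to the cache under the key stored in the DAG's shared data. Below is a minimal, non-streaming wiring sketch; the DAG name, the task names, the model_task placeholder, and the join function are illustrative assumptions, not code from this commit.

from dbgpt.core.awel import DAG, InputOperator, JoinOperator, SimpleCallDataInputSource
from dbgpt.storage.cache import CacheManager
from dbgpt.storage.cache.operator import (
    CachedModelOperator,
    ModelCacheBranchOperator,
    ModelSaveCacheOperator,
)


def build_cached_llm_dag(cache_manager: CacheManager, model_task) -> DAG:
    # `model_task` is assumed to be a MapOperator[ModelRequest, ModelOutput] that
    # actually calls the LLM and was created with task_name="llm_model_node";
    # it is a placeholder for this sketch, not something defined in this commit.
    with DAG("llm_cache_example_dag") as dag:
        input_task = InputOperator(input_source=SimpleCallDataInputSource())
        # Route to the cache task on a hit, otherwise to the model task.
        branch_task = ModelCacheBranchOperator(
            cache_manager,
            model_task_name="llm_model_node",
            cache_task_name="llm_cache_node",
        )
        cache_task = CachedModelOperator(cache_manager, task_name="llm_cache_node")
        save_task = ModelSaveCacheOperator(cache_manager)
        # Only one branch runs for a given request; keep whichever side produced
        # an output.
        join_task = JoinOperator(combine_function=lambda cached, fresh: cached or fresh)

        input_task >> branch_task
        branch_task >> cache_task >> join_task
        branch_task >> model_task >> save_task >> join_task
    return dag

A streaming pipeline would swap CachedModelOperator for CachedModelStreamOperator and ModelSaveCacheOperator for ModelStreamSaveCacheOperator.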
@@ -41,14 +41,19 @@ class BaseQuery(orm.Query):
    .. code-block:: python

        from dbgpt.storage.metadata import db, Model


        class User(Model):
            __tablename__ = "user"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))
            fullname = Column(String(50))


        with db.session() as session:
            pagination = session.query(User).paginate_query(
                page=1, page_size=10
            )
            print(pagination)

@@ -100,25 +105,37 @@ class DatabaseManager:

        from urllib.parse import quote_plus as urlquote, quote
        from dbgpt.storage.metadata import DatabaseManager, create_model

        db = DatabaseManager()
        # Use sqlite with memory storage.
        url = f"sqlite:///:memory:"
        engine_args = {
            "pool_size": 10,
            "max_overflow": 20,
            "pool_timeout": 30,
            "pool_recycle": 3600,
            "pool_pre_ping": True,
        }
        db.init_db(url, engine_args=engine_args)

        Model = create_model(db)


        class User(Model):
            __tablename__ = "user"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))
            fullname = Column(String(50))


        with db.session() as session:
            session.add(User(name="test", fullname="test"))
            # db will commit the session automatically by default.
            # session.commit()
            assert (
                session.query(User).filter(User.name == "test").first().name
                == "test"
            )


        # More usage:

@@ -307,6 +324,7 @@ class DatabaseManager:
        >>> db.init_default_db(sqlite_path)
        >>> with db.session() as session:
        ...     session.query(...)
        ...

        Args:
            sqlite_path (str): The sqlite path.
@@ -353,12 +371,17 @@ class DatabaseManager:

        from dbgpt.storage.metadata import DatabaseManager
        from sqlalchemy import Column, Integer, String

        db = DatabaseManager.build_from("sqlite:///:memory:")


        class User(db.Model):
            __tablename__ = "user"
            id = Column(Integer, primary_key=True)
            name = Column(String(50))
            fullname = Column(String(50))


        db.create_all()
        with db.session() as session:
            session.add(User(name="test", fullname="test"))

@@ -397,7 +420,8 @@ Examples:
    >>> sqlite_path = "/tmp/dbgpt.db"
    >>> db.init_default_db(sqlite_path)
    >>> with db.session() as session:
    ...     session.query(...)
    ...

    >>> from dbgpt.storage.metadata import db, Model
    >>> from urllib.parse import quote_plus as urlquote, quote
@@ -407,16 +431,24 @@ Examples:
    >>> user = "root"
    >>> password = "123456"
    >>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}"
    >>> engine_args = {
    ...     "pool_size": 10,
    ...     "max_overflow": 20,
    ...     "pool_timeout": 30,
    ...     "pool_recycle": 3600,
    ...     "pool_pre_ping": True,
    ... }
    >>> db.init_db(url, engine_args=engine_args)
    >>> class User(Model):
    ...     __tablename__ = "user"
    ...     id = Column(Integer, primary_key=True)
    ...     name = Column(String(50))
    ...     fullname = Column(String(50))
    ...
    >>> with db.session() as session:
    ...     session.add(User(name="test", fullname="test"))
    ...     session.commit()
    ...
"""