mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 03:50:42 +00:00
chore(build): Fix typo and new pre-commit config (#987)
This commit is contained in:
@@ -1,46 +1,45 @@
|
||||
from dbgpt.core.interface.cache import (
|
||||
CacheClient,
|
||||
CacheConfig,
|
||||
CacheKey,
|
||||
CachePolicy,
|
||||
CacheValue,
|
||||
)
|
||||
from dbgpt.core.interface.llm import (
|
||||
LLMClient,
|
||||
ModelInferenceMetrics,
|
||||
ModelMetadata,
|
||||
ModelOutput,
|
||||
ModelRequest,
|
||||
ModelRequestContext,
|
||||
ModelOutput,
|
||||
LLMClient,
|
||||
ModelMetadata,
|
||||
)
|
||||
from dbgpt.core.interface.message import (
|
||||
ConversationIdentifier,
|
||||
MessageIdentifier,
|
||||
MessageStorageItem,
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
OnceConversation,
|
||||
StorageConversation,
|
||||
MessageStorageItem,
|
||||
ConversationIdentifier,
|
||||
MessageIdentifier,
|
||||
)
|
||||
from dbgpt.core.interface.prompt import (
|
||||
PromptTemplate,
|
||||
PromptManager,
|
||||
StoragePromptTemplate,
|
||||
)
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
|
||||
from dbgpt.core.interface.serialization import Serializable, Serializer
|
||||
from dbgpt.core.interface.cache import (
|
||||
CacheKey,
|
||||
CacheValue,
|
||||
CacheClient,
|
||||
CachePolicy,
|
||||
CacheConfig,
|
||||
from dbgpt.core.interface.prompt import (
|
||||
PromptManager,
|
||||
PromptTemplate,
|
||||
StoragePromptTemplate,
|
||||
)
|
||||
from dbgpt.core.interface.serialization import Serializable, Serializer
|
||||
from dbgpt.core.interface.storage import (
|
||||
DefaultStorageItemAdapter,
|
||||
InMemoryStorage,
|
||||
QuerySpec,
|
||||
ResourceIdentifier,
|
||||
StorageError,
|
||||
StorageInterface,
|
||||
StorageItem,
|
||||
StorageItemAdapter,
|
||||
StorageInterface,
|
||||
InMemoryStorage,
|
||||
DefaultStorageItemAdapter,
|
||||
QuerySpec,
|
||||
StorageError,
|
||||
)
|
||||
|
||||
|
||||
__ALL__ = [
|
||||
"ModelInferenceMetrics",
|
||||
"ModelRequest",
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from abc import ABC
|
||||
from typing import List
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
|
||||
|
@@ -27,7 +27,7 @@ class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
|
||||
.. code-block:: python
|
||||
|
||||
class MyStreamOperator(StreamifyAbsOperator[int, int]):
|
||||
async def streamify(self, input_value: int) -> AsyncIterator[int]
|
||||
async def streamify(self, input_value: int) -> AsyncIterator[int]:
|
||||
for i in range(input_value):
|
||||
yield i
|
||||
"""
|
||||
@@ -54,7 +54,7 @@ class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
|
||||
.. code-block:: python
|
||||
|
||||
class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
|
||||
async def unstreamify(self, input_value: AsyncIterator[int]) -> int
|
||||
async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
|
||||
value_cnt = 0
|
||||
async for v in input_value:
|
||||
value_cnt += 1
|
||||
@@ -85,7 +85,9 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
|
||||
.. code-block:: python
|
||||
|
||||
class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
|
||||
async def unstreamify(self, input_value: AsyncIterator[int]) -> AsyncIterator[int]
|
||||
async def unstreamify(
|
||||
self, input_value: AsyncIterator[int]
|
||||
) -> AsyncIterator[int]:
|
||||
async for v in input_value:
|
||||
yield v + 1
|
||||
"""
|
||||
|
@@ -283,6 +283,7 @@ class ConversationMapperOperator(
|
||||
|
||||
import asyncio
|
||||
from dbgpt.core.operator import ConversationMapperOperator
|
||||
|
||||
messages_by_round = [
|
||||
[
|
||||
ModelMessage(role="human", content="Hi", round_index=1),
|
||||
@@ -290,7 +291,9 @@ class ConversationMapperOperator(
|
||||
],
|
||||
[
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(role="human", content="What's the error?", round_index=2),
|
||||
ModelMessage(
|
||||
role="human", content="What's the error?", round_index=2
|
||||
),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
],
|
||||
[
|
||||
@@ -303,7 +306,9 @@ class ConversationMapperOperator(
|
||||
ModelMessage(role="human", content="Hi", round_index=1),
|
||||
ModelMessage(role="ai", content="Hello!", round_index=1),
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(role="human", content="What's the error?", round_index=2),
|
||||
ModelMessage(
|
||||
role="human", content="What's the error?", round_index=2
|
||||
),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
]
|
||||
@@ -315,8 +320,13 @@ class ConversationMapperOperator(
|
||||
class MyMapper(ConversationMapperOperator):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def map_multi_round_messages(self, messages_by_round: List[List[ModelMessage]]) -> List[ModelMessage]:
|
||||
|
||||
def map_multi_round_messages(
|
||||
self, messages_by_round: List[List[ModelMessage]]
|
||||
) -> List[ModelMessage]:
|
||||
return messages_by_round[-1]
|
||||
|
||||
|
||||
operator = MyMapper()
|
||||
messages = operator.map_multi_round_messages(messages_by_round)
|
||||
assert messages == [
|
||||
@@ -371,7 +381,9 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
# No history
|
||||
messages = [ModelMessage(role="human", content="Hello", round_index=1)]
|
||||
operator = BufferedConversationMapperOperator(last_k_round=1)
|
||||
assert operator.map_messages(messages) == [ModelMessage(role="human", content="Hello", round_index=1)]
|
||||
assert operator.map_messages(messages) == [
|
||||
ModelMessage(role="human", content="Hello", round_index=1)
|
||||
]
|
||||
|
||||
Transform with history messages
|
||||
|
||||
|
Reference in New Issue
Block a user