mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 12:21:08 +00:00
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
181 lines
5.9 KiB
Python
181 lines
5.9 KiB
Python
from datetime import datetime
|
|
from typing import List, Optional
|
|
|
|
from sqlalchemy import (
|
|
Boolean,
|
|
Column,
|
|
DateTime,
|
|
Index,
|
|
Integer,
|
|
String,
|
|
Text,
|
|
and_,
|
|
desc,
|
|
func,
|
|
or_,
|
|
)
|
|
|
|
from dbgpt.storage.metadata import BaseDao, Model
|
|
|
|
|
|
class GptsMessagesEntity(Model):
    """ORM mapping for the ``gpts_messages`` table.

    Each row stores one message of a conversation: who sent it, who received
    it, the dialogue round it belongs to, and the textual payloads attached
    to it (content, context, review info, action report, resource info).
    """

    __tablename__ = "gpts_messages"
    # Composite index over the columns the DAO filters and sorts on.
    __table_args__ = (Index("idx_q_messages", "conv_id", "rounds", "sender"),)

    id = Column(Integer, primary_key=True, comment="autoincrement id")

    # --- conversation routing -------------------------------------------
    conv_id = Column(
        String(255),
        nullable=False,
        comment="The unique id of the conversation record",
    )
    sender = Column(
        String(255),
        nullable=False,
        comment="Who speaking in the current conversation turn",
    )
    receiver = Column(
        String(255),
        nullable=False,
        comment="Who receive message in the current conversation turn",
    )
    model_name = Column(
        String(255),
        nullable=True,
        comment="message generate model",
    )
    rounds = Column(Integer, nullable=False, comment="dialogue turns")
    is_success = Column(
        Boolean,
        default=True,
        nullable=True,
        comment="is success",
    )

    # --- owning application ---------------------------------------------
    app_code = Column(
        String(255),
        nullable=False,
        comment="The message in which app",
    )
    app_name = Column(
        String(255),
        nullable=False,
        comment="The message in which app name",
    )

    # --- message payloads -----------------------------------------------
    # ``content`` and ``action_report`` declare an explicit length of
    # 2**31 - 1; the remaining Text columns use the backend default.
    content = Column(
        Text(length=2**31 - 1),
        nullable=True,
        comment="Content of the speech",
    )
    current_goal = Column(
        Text,
        nullable=True,
        comment="The target corresponding to the current message",
    )
    context = Column(
        Text,
        nullable=True,
        comment="Current conversation context",
    )
    review_info = Column(
        Text,
        nullable=True,
        comment="Current conversation review info",
    )
    action_report = Column(
        Text(length=2**31 - 1),
        nullable=True,
        comment="Current conversation action report",
    )
    resource_info = Column(
        Text,
        nullable=True,
        comment="Current conversation resource info",
    )
    role = Column(
        String(255),
        nullable=True,
        comment="The role of the current message content",
    )

    # --- bookkeeping timestamps (naive UTC via datetime.utcnow) -----------
    created_at = Column(
        DateTime,
        default=datetime.utcnow,
        comment="create time",
    )
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        comment="last update time",
    )
|
|
|
|
|
|
class GptsMessagesDao(BaseDao):
    """Data-access object for ``GptsMessagesEntity`` rows.

    Every method obtains its own raw session from ``get_raw_session()`` and
    closes it in a ``finally`` clause, so the session is released even when
    the query or the commit raises.
    """

    def append(self, entity: dict):
        """Insert one message row and return its primary key.

        Args:
            entity: Mapping of column name to value. ``is_success`` falls
                back to ``True`` when absent; every other key falls back to
                ``None`` when absent.

        Returns:
            The ``id`` attribute of the new row, read after the commit.

        Raises:
            Any exception raised by the session while adding or committing
            is propagated unchanged after the session has been closed.
        """
        session = self.get_raw_session()
        try:
            message = GptsMessagesEntity(
                conv_id=entity.get("conv_id"),
                sender=entity.get("sender"),
                receiver=entity.get("receiver"),
                content=entity.get("content"),
                is_success=entity.get("is_success", True),
                role=entity.get("role"),
                model_name=entity.get("model_name"),
                context=entity.get("context"),
                rounds=entity.get("rounds"),
                app_code=entity.get("app_code"),
                app_name=entity.get("app_name"),
                current_goal=entity.get("current_goal"),
                review_info=entity.get("review_info"),
                action_report=entity.get("action_report"),
                resource_info=entity.get("resource_info"),
            )
            session.add(message)
            session.commit()
            # Read the key while the session is still open; the ``finally``
            # clause runs only after this expression has been evaluated.
            message_id = message.id
            return message_id
        finally:
            session.close()

    def get_by_agent(
        self, conv_id: str, agent: str
    ) -> Optional[List[GptsMessagesEntity]]:
        """Return the messages of a conversation that involve ``agent``.

        Args:
            conv_id: Conversation identifier to match.
            agent: Name matched against both ``sender`` and ``receiver``.

        Returns:
            Matching rows ordered by ``rounds`` ascending.
        """
        session = self.get_raw_session()
        try:
            query = session.query(GptsMessagesEntity)
            # NOTE(review): when ``agent`` is falsy neither filter is applied
            # (not even the conv_id one), so every row of the table is
            # returned. Kept as-is for compatibility; confirm it is intended.
            if agent:
                query = query.filter(
                    GptsMessagesEntity.conv_id == conv_id
                ).filter(
                    or_(
                        GptsMessagesEntity.sender == agent,
                        GptsMessagesEntity.receiver == agent,
                    )
                )
            return query.order_by(GptsMessagesEntity.rounds).all()
        finally:
            session.close()

    def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessagesEntity]]:
        """Return all messages of one conversation.

        Args:
            conv_id: Conversation identifier to match. A falsy value
                disables the filter, so all rows are returned.

        Returns:
            Matching rows ordered by ``rounds`` ascending.
        """
        session = self.get_raw_session()
        try:
            query = session.query(GptsMessagesEntity)
            if conv_id:
                query = query.filter(GptsMessagesEntity.conv_id == conv_id)
            return query.order_by(GptsMessagesEntity.rounds).all()
        finally:
            session.close()

    def get_between_agents(
        self,
        conv_id: str,
        agent1: str,
        agent2: str,
        current_goal: Optional[str] = None,
    ) -> Optional[List[GptsMessagesEntity]]:
        """Return the messages exchanged between two agents.

        Args:
            conv_id: Conversation identifier to match.
            agent1: One side of the exchange.
            agent2: The other side of the exchange.
            current_goal: When truthy, additionally restrict the result to
                rows whose ``current_goal`` equals this value.

        Returns:
            Rows sent in either direction between the two agents, ordered
            by ``rounds`` ascending.
        """
        session = self.get_raw_session()
        try:
            query = session.query(GptsMessagesEntity)
            # NOTE(review): the conv_id filter is applied only when both
            # agent names are truthy; otherwise the query spans every
            # conversation. Kept as-is for compatibility.
            if agent1 and agent2:
                one_to_two = and_(
                    GptsMessagesEntity.sender == agent1,
                    GptsMessagesEntity.receiver == agent2,
                )
                two_to_one = and_(
                    GptsMessagesEntity.sender == agent2,
                    GptsMessagesEntity.receiver == agent1,
                )
                query = query.filter(
                    GptsMessagesEntity.conv_id == conv_id
                ).filter(or_(one_to_two, two_to_one))
            if current_goal:
                query = query.filter(
                    GptsMessagesEntity.current_goal == current_goal
                )
            return query.order_by(GptsMessagesEntity.rounds).all()
        finally:
            session.close()

    def get_last_message(self, conv_id: str) -> Optional[GptsMessagesEntity]:
        """Return the message with the highest ``rounds`` in a conversation.

        Args:
            conv_id: Conversation identifier to match. A falsy value skips
                both the filter and the ordering, so the first row the
                database yields is returned.

        Returns:
            The selected row, or ``None`` when the query matches nothing.
        """
        session = self.get_raw_session()
        try:
            query = session.query(GptsMessagesEntity)
            if conv_id:
                query = query.filter(
                    GptsMessagesEntity.conv_id == conv_id
                ).order_by(desc(GptsMessagesEntity.rounds))
            return query.first()
        finally:
            session.close()
|