From 19e3679d9f9b43cdb3c161b475a1462c3200597d Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Tue, 13 Jun 2023 20:33:34 +0800 Subject: [PATCH] =?UTF-8?q?fix=20plugin=20mode=20bug=EF=BC=9BOptimize=20th?= =?UTF-8?q?e=20parsing=20logic=20for=20model=20response?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pilot/common/plugins.py | 5 +- pilot/configs/config.py | 1 - pilot/configs/model_config.py | 12 ++- pilot/connections/rdbms/py_study/pd_study.py | 100 +++++++++++-------- pilot/mock_datas/db-gpt-test.db | Bin 0 -> 1060864 bytes pilot/out_parser/base.py | 36 ++++--- pilot/scene/chat_execution/chat.py | 7 +- pilot/scene/chat_execution/out_parser.py | 16 +-- pilot/scene/chat_execution/prompt.py | 17 ++-- pilot/server/__init__.py | 2 + pilot/server/bar_chart.html | 1 - pilot/server/db-gpt-test.db | Bin 12288 -> 0 bytes tests/unit/test_plugins.py | 1 + 13 files changed, 118 insertions(+), 80 deletions(-) create mode 100644 pilot/mock_datas/db-gpt-test.db delete mode 100644 pilot/server/bar_chart.html delete mode 100644 pilot/server/db-gpt-test.db diff --git a/pilot/common/plugins.py b/pilot/common/plugins.py index aeb46970d..09931c90e 100644 --- a/pilot/common/plugins.py +++ b/pilot/common/plugins.py @@ -15,6 +15,7 @@ import requests from auto_gpt_plugin_template import AutoGPTPluginTemplate from pilot.configs.config import Config +from pilot.configs.model_config import PLUGINS_DIR from pilot.logs import logger @@ -82,7 +83,7 @@ def load_native_plugins(cfg: Config): headers={'Authorization': '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'}) if response.status_code == 200: - plugins_path_path = Path(cfg.plugins_dir) + plugins_path_path = Path(PLUGINS_DIR) files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*')) for file in files: os.remove(file) @@ -111,7 +112,7 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate current_dir = os.getcwd() print(current_dir) # Generic plugins - plugins_path_path = Path(cfg.plugins_dir) + plugins_path_path = Path(PLUGINS_DIR) logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}") logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}") diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 450cb6901..06b91e33b 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -88,7 +88,6 @@ class Config(metaclass=Singleton): self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message") ### The associated configuration parameters of the plug-in control the loading and use of the plug-in - self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins") self.plugins: List[AutoGPTPluginTemplate] = [] self.plugins_openai = [] diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 36d615043..f7733a42e 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -13,8 +13,18 @@ VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store") LOGDIR = os.path.join(ROOT_PATH, "logs") DATASETS_DIR = os.path.join(PILOT_PATH, "datasets") DATA_DIR = os.path.join(PILOT_PATH, "data") - nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path +PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") +FONT_DIR = os.path.join(PILOT_PATH, "fonts") + +# 获取当前工作目录 +current_directory = os.getcwd() +print("当前工作目录:", current_directory) + +# 设置当前工作目录 +new_directory = PILOT_PATH +os.chdir(new_directory) +print("新的工作目录:", os.getcwd()) DEVICE = ( "cuda" diff --git a/pilot/connections/rdbms/py_study/pd_study.py b/pilot/connections/rdbms/py_study/pd_study.py index 092a44213..68784f9b7 100644 --- a/pilot/connections/rdbms/py_study/pd_study.py +++ b/pilot/connections/rdbms/py_study/pd_study.py @@ -9,46 +9,66 @@ from pyecharts import options as opts CFG = Config() +# +# if __name__ == "__main__": +# # 创建连接池 +# engine = create_engine('mysql+pymysql://root:aa123456@localhost:3306/gpt-user') +# +# # 从连接池中获取连接 +# +# +# # 归还连接到连接池中 +# +# # 执行SQL语句并将结果转化为DataFrame +# query = "SELECT * FROM users" +# df = pd.read_sql(query, engine.connect()) +# df.style.set_properties(subset=['name'], **{'font-weight': 'bold'}) +# # 导出为HTML文件 +# with open('report.html', 'w') as f: +# f.write(df.style.render()) +# +# # # 设置中文字体 +# # font = FontProperties(fname='SimHei.ttf', size=14) +# # +# # colors = np.random.rand(df.shape[0]) +# # df.plot.scatter(x='city', y='user_name', c=colors) +# # plt.show() +# +# # 查看DataFrame +# print(df.head()) +# +# +# # 创建数据 +# x_data = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] +# y_data = [820, 932, 901, 934, 1290, 1330, 1320] +# +# # 生成图表 +# bar = ( +# Bar() +# .add_xaxis(x_data) +# .add_yaxis("销售额", y_data) +# .set_global_opts(title_opts=opts.TitleOpts(title="销售额统计")) +# ) +# +# # 生成HTML文件 +# bar.render('report.html') +# +# + if __name__ == "__main__": - # 创建连接池 - engine = create_engine('mysql+pymysql://root:aa123456@localhost:3306/gpt-user') + def __extract_json(s): + i = s.index('{') + count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 + for j, c in enumerate(s[i + 1:], start=i + 1): + if c == '}': + count -= 1 + elif c == '{': + count += 1 + if count == 0: + break + assert (count == 0) # 检查是否找到最后一个'}' + return s[i:j + 1] - # 从连接池中获取连接 - - - # 归还连接到连接池中 - - # 执行SQL语句并将结果转化为DataFrame - query = "SELECT * FROM users" - df = pd.read_sql(query, engine.connect()) - df.style.set_properties(subset=['name'], **{'font-weight': 'bold'}) - # 导出为HTML文件 - with open('report.html', 'w') as f: - f.write(df.style.render()) - - # # 设置中文字体 - # font = FontProperties(fname='SimHei.ttf', size=14) - # - # colors = np.random.rand(df.shape[0]) - # df.plot.scatter(x='city', y='user_name', c=colors) - # plt.show() - - # 查看DataFrame - print(df.head()) - - - # 创建数据 - x_data = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] - y_data = [820, 932, 901, 934, 1290, 1330, 1320] - - # 生成图表 - bar = ( - Bar() - .add_xaxis(x_data) - .add_yaxis("销售额", y_data) - .set_global_opts(title_opts=opts.TitleOpts(title="销售额统计")) - ) - - # 生成HTML文件 - bar.render('report.html') + ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" + print(__extract_json(ss)) \ No newline at end of file diff --git a/pilot/mock_datas/db-gpt-test.db b/pilot/mock_datas/db-gpt-test.db new file mode 100644 index 0000000000000000000000000000000000000000..4dd7921b4a2fb673c52b45c22e39bca6fb352772 GIT binary patch literal 1060864 zcmeI*31Ae}{lM`}!X=j|XjM>GMZp8z`_QVJiyR~&CpoA`fB->AVhAW|b+J|JEf&vO zOVt*&UiGThqfx5@9#synUS74;+7^$t*7mRe-`m~UefxGdB#~kHt(7KmY**5I_I{1Q0*~0R#|0V8<6&{PL^8 z>o0Pik!9>kPHbL13*dzS0tg_000IagfB*srAbF zgC#DCEf^v3oLN)mb#zVXYFXT6H8ot)U7gKsC$x9YY3WpRYNhc~)Jv_~I!FEb!4Neo zuTjxSwAy~L8P&hY`X{ut>y`J^D<2q}87b3_&h|M=W_O*?*1V`C;|d4IRv4qLFuS>{ zWnO#dvWzQPT~hsRUevy%txN5X1-M)GaLY-h50Ii>YPUO@mz}V0 z{w0fBI#1|fW5nFps{bzV!GoI4p? z)we&ZwXw zPtfc4HPx4TT|sxymNiOMw(7;E{n(fJ(*B_IeaT^U{`usvA2GY+v>m zov|;ut)2GS?QL^v+m}TR2D&eo$-YL$$b^`=b`W+|hh4os3@DKSJwC7RlF=kTw<5o= z%x0^p%g-;YsgZ7|F+G4|<!cZ|(wp|-AGP4k=E=GC{f+Ol?% zxdH6Uu-!6`kM8~F&#wGnu+7$e@O!ha8Z~4f`jK=QQpY{bIQ=YI=`5S=#LJJ9F&pl) zM`DF7UM>eo9}XNTKjdXypRB2Brns?8#%l5_-h5c{u^H8l_A+LpOJbYRs!6CWV)`rE zUbjlc8;J*rit#>3$~vzTOI0*RpH0(3d~w){LRQerX|7(+q0c1ZP$UdT6G>f)q>N1D z_7bg|;b(>&IWa@^3~5U*dPLH<8H%nu#yY}jo3f|Png`n2T-wUY?SgFddVu2bpHM03;DzM1*Ne6*d=7L^ifM9%lfVbXIFXKEvo*9YpBd!?MJDmHK&#)2Q1-yuiEy;KT_);nYY(q*56cpZ z61~l~Hj!AEIZC%9<}g!hXF5oUmJ@5~ZH8kNhDCQ=TNF7oj@V)?pJ*`}^{^$EzqHv}Avd8-= zgFPInkkm+rv>64o7|l^+&PUov{5U%isz`4_IfEG70D6OK1%w{q#H z(bh(GOMQD%%T{$WX}QVrq=RMEk4)|Hx#ZO=qxq3T;>+XxjAYU@Zr7@24ApoDOR-By zeP!-vYTcBK-qq+rt&=IWRU&zFjzOVVt)xL4$tB}R${tDgJR~Dk5RU_GZQUUY(^O5Hv)a+7Cq`nc zfsnQ&IZU_dfA7;%qqq`R0!Q_?_ zAF-12GH14vQb8r+&4#S{azH~`RT(5X+BngcSfNfAZ!WsfKB!lnh+dQ)z(c01nvzE>nZlS(tx|3?zY{_fc9I?e&(OpO^RcnW13wxfZnyK%AWw@I64y%iyr1qlSR4HqD z(&8TSX1&OKIGUB+>1m`Uts2>Gq#Jqr0Qx$+74MseWmz2do`Y|K7^!(XYPc^$d|5c;ia}} zx9Y@uli@2(wVA$uy0BW`($u)nY^gfS5kO!&5zsF{+ew$clHK^?oI_tpKX&SCDzJqw z<-VpW>^1@eDUkltMTTEe9N1461G$^DjX(y0=tubs3kN3M|L+l_&ngC{x7&NG{rpry z8xQfPEdxQ&UxzGXC$$~vm;R@2C^geB8yet7`v2@3;0}K!n@@bZ_~q>$BcZ?Tx8jE2)YmKn@XGMVIa2@F(bj*mh~NFw?1DD5tW!PpvZGbnDuFG?C7QRo z5ZOvN)3)r11yUzu&B@9#oBkEqL@50-J)6`4eZdyqifmyI^_F^yLy?7P8%ym#Y9`UN zBN9e`CbDUw?pV&;@tzvDno0E2Zl%=JY)W62h!Y20<-&;tsgcoOY*u=hxsxM9sm(+) zN!yJyWtRoUleNStkX0FLD6zA#L~IgDp0(zZHe)4|rWqPc+F7X+wpPlvTah*mX>lY> zzugSmk~Oq1Y!gilvgs>Cn#>Bt7EAwRVMKmVSgR(>tE zimhfXjV-cf)9;0`sUBvNo3sAxQn-gTqKiU_8c~VV0U61sMNX_6|5jg5>2WANqnX5W z+Zg2ZQ7_S%#D?PsBAQaOYR4V=@ippbzq4SMs`^+;sVU2Qdc}BUi=hnFawRWxV^V$v zCNenV!&vn>BCG=K71H8pXjDj9-gBb2I~a;K42lpog4TJtGWvw}TQW4Vl2Yc`Zfr;U~eZYDk4Lb(1@oj0l$MZ1h!)WEUY-TIh4R7UotcJi2A<(-AMLq3GUoNdZp6n?b zv7R)iEFiH>0{Y{=ZEC2mEB9rn^tJ6Tq<(jXt!$NOZ-)1FEY`PkIRo65(xH?JTScAH z3WM_N|K1+2FXYnPQ%ApqWm~<^1AEB&N#YB=Ec8>KY3A!D z(Ep>fJ4)z?OzhDa^JmCR^uv5_ue3dFs{b`IZLK@gf&t3g)3v+R7vOY_4Pb6Z>E{3v ze1)6d;$N<|P5YkV&sn7Y^BwBk@6#uXq;A=5p`>|j#-4~ni7+XOT)PtQYEnNicSoxC z@)9T2w3@xUrs=1c9w&7(He>y*Y>5u1tEJ@}R@`&TW7Ol7aHMXEz#!WgwYlihNEQ8u zv5)3snvh;M8fRr{N-ge=)`|YK>oalLVPQQj`xUaK^jWVV%dJ;j!%C!?CjAC=^76EX zv1Gy_9Q%#ONG7s0q&JhckywR9BmMssltVs^%5hBoD`xGH4a=hRuQbCfl)h#>E4wf; zJ`)YPwfVEPI39-MQ6f;A?oP~VAEl+YF~iutg!O7^;*o=my-ridNX0f_WHZdBQHs{= zVOVNciR4zbEr(Jng_7&4xsXl%8u{p}S6`iK?OW?brg}aSTb)ncyJ4ZEHO(1)sT$BP zm=&yf)pE)%S_>0Xdm>Ck(K99bGm-5n3dR0Bg4F4jqVKQFrg=q1mm|#rBNa}b*ZLF> zlMF+V`XOsf4l|jsd+Kj&5ZZgH6J2Hv*DclD z@^M!!bZ@B{SsqRryXK8A(LgV@nA^jNkTL>gPeX~;dljjkn(HKu$t|`L-PdrMda9jn z*+)wwdLTk6HM>(X$%#yVelWS8_T-hUb0>bO?CBa9>tyWDj2I2|ayGhmTdW-Jv4%yZ zDup67RVFr0;uj4`CGUHrW~Lum@jq&$*GLQ=z1e6D8L?(EBuir5&|Aqow(W^(xrs1s zBZ+N=$He~JqVAxtEpe80d5q>m+Tw8ZA0@CpyJRI=N&N@;)JC`dlZCcXLIe;%009IL zKmY**5cui|SRW?T=e);^-CGlzx9V7*&aC z@@`|kEcTJgTE5Ykx4NI5W7V)S@r1P?o``k97Ef3osN)Iib6h-OJ%5ZRhQu}#Pgwtk zTs&dP#S`jZlUJ{JVt8~ukr-hntPiE}RYsXLMw^M9&4l$?H{QlBW{q9V#8@+7Jy4Ig zA^&oH#SThka6S-z0&rIZ-i2^fGXeOqbi6S#m zY$i&~M5&pWW+uwaM7f!$FcZ_wM5URiG85HiqQ*?rnu$6yF~dyMn~4T9(P$=`%*0GH z@hvlPoS8V@Oq^gQPBatEW@46^m~AHJn28oMG1pAYGZXX8M5~!tU?xs76AR76A~Vru zCL(_lQ7lnYhMGTx%wNU?#3J6F)Q)*PDqSnTb_q;>Tv<1~YM^ znYhVJ{KQQB)J&{46E~ZQTg=3*X5wdN;x;pJyP3GdOx$TE?lKcUHxqZ8iF?e%y=LM* zGjYF}_=TBxz)U=7CVpuq9x@XTn~6uv#G_{7S7xHyOgv^L9yb$Dn2BGTi6_m(Q)c4n zSYqg!$QM^(o9*-1OD)KXEqzcfW>&mTi+`)d?`!e@Y4K>e?q$W=*;k7lTAZuJ)ml7T zi@UV=94)?5i&ty$16sURi(k{?ziRRSw0M+Unyo(Xqs2#PaY&1+w0M>lFV^DkYVj3X z{8KIdg%+>T;#al!JuUvP7LSy>ht=o3wfJx?K30n>wYXV}JGJ;bTD($=f1<_rYw^=s z{E8O8tHqyc@d&xWS$&?M#Z$ES7%iTz#V2a<$y$827JpxhZ_?uXwD>74ep!qEqQ(Ew z;^A@+w)#9?ix1P{qqVp~i%-zv4lQ1;#g}XGjaq!K7C))QFKO|gwfGY)-bvmjSbg40 zix1V}Z)$P579X$0?OGhx;uTtagBIVT#lP0#OcZ_whmwfLV}JVf3PT7CY87EjjVpca>E@k}jVsKwvb;!CvnM_PQB z7C)xN>$UhTE&f=G2g_SltIxY@@gyw{XmN=aH)-)nT6~5UU#!K~Yw?|0+^xkgYVn&| z{E-$9l6TcspLf&ZgS6PM#l>3OsKpDk_;fA4NQ-}{#dm1&ueA6DEq+6bKh)wJdE;;O zd7Ks>sKq`lF4E!#EpFA~)3o?PExt~RZ`a~SwfHw${JIu@pvBqp*~IGeSS_BY#a=C* zs>StMJYS2KY4HVG`~xk%O^YAV;^(#ak6OH0V*R}q@#mkbw~07>xenI36MM^U+cOqq z$EJJiWwQ+#Dl_A(8i_R`Pl*a-K9P!6vKrA3Zg!3=8=lfgrd6%&SIYLYhRJI7lvS;S z@u&N(bZpv>eJLIs%ih;nPtLOj53I#sGWR8~xouwa!uHOZ`7LdyNwH;(7}&ntwQuaW$heKZ5(6@P z_opM4e4sGCedpgNjT|x%V>RiNQ`fUJ{Yd?-+gMkv$ri6AiGAAQC3BE8YW+g5)vh&F z%@jA5m08>EX{y@!vR>+v-kX=bHY$)3`UktrDx{UKm|YUq!)`qV~WhDx6r{-P>(XV}nIiS}lAZ^vSN zJC`%SZ7Cf}sjyYlDXoy`cLBS!*u!v5S)}Us|11522WvaVLlD)cK7G3IS`F#bq&KPh zf~|>Dn6{>o-mcTJzWBv@8iGG;x?UlBtP4&Ms8ODE%Bjg0FDX@h>a}WFOOiy2$XgraJ%o^yxzCtFL!<)YlI2MrudJrQh3~ zK=e(?)^@A6-wJMR`>clm0$;j7NcDg_m>dR@QzsD8d8mAn`UX8YtwFW5Bo8N^PpiOV7 zS$moaJMv?iG`rKsRPQfA=^Eay#$CD&Ze6agSFG^XZD#vl%RaDO4D}VdkKGCNv3@^R zAK*Rf$DVHG<+@Kh817Ynx9wWgt4$tW=|5d$_@hVL$ES;A`M%G6;e(^^U1SXe z`cq(sy*%|N0&O9H00IagfWQ|bkl{~qC;l>luf`{eFT|3K>=*)(kC*l@v#(#qFA>RJ ze@UOK(|@*!|CKv^IuP+)h(5~N$_HA}O@_=Zr|+m)tJm`xADi{_c6!~R9x85!UFxB? z+vy^HUnqKDdgZVU*T8K$>$lIblkr-yE!}l&%gTLSH-2UA>xu(fe_(oL&(LSCqS0TP zjQhlgb%wtlOdX!S{#lC*_M{O&0D-Tez`)$(zD&rk-X49KnrwU91-9Yi%C>Ku77#!{ z38epRs0@GMFP|>9-$An@{AT}G@Hx{!4v378{X*Ni)r>7~Ys0DQ+f!59oZ+$1HWS7g zu=iI(w1dEx^?fMvJqd(*T9~}74^@*_+`;B}A2Of^Eb*`@!#|yo`cHUl^(TwOwgz;U z`rf$t7ONM}JIq!3zR~S`jqUbeCEVg)uq9Fh9ethO&BpgV7JNA)BxD`pUYw*>)g58l zJ%F9ee6flj(*fLizcy`;wx-5jC>r+Wrz5FSA686S@9Q?R!wrt@^8h7XCmDNR|A#+X7Qdd575`V| zH)-)ZTKsP<9=3O){+?QVh!!8E#bsK2oEEof@mX4YnHK+8i|^LrCnSzTSWnsFiCs*H zUCqQ;GhzLC;&>aonKgDd6W=fs)?d9lIC6CTf#8ux<>uuV9JA=yrehby{|3Wb_etRbui)L5~G&PMH!$)el|> zAbnxe7GR$VnR;kIu zZMMo$Hk`=mTBMX}F6AgG#Zu~|%#zYB z?_{%mL&`)cPASJonI@${N{f_EDQ8K!NXj)*ek$c2Dcw?@lk%FB zKTG*oN|u}#qos_OqRx?NGp0FxUa!mN4!AvKMP*L6D;RM2eSUvKWs%qCcKAK6pvzO( zP!R}v{Q;NT=LXLvo=kvOQK1cr4N~d($r)v2t z^NQSFzc1i%dVSv7qQZbb7;s7PI;ZF7`<)(Vz~KuzJ*Cyt{XvJv<8Zm1&bopUi9FJj z(-SPHEOGl?0k_-d@&;;VdUHpK|v1+oVTX9&gXJ@W#>E&cSGq6>6FhE z2zdPd>WU_}bU_xl1D>Y53ZLKU49XGm1*erYioa;dh8!h@1p%Mm6L1D(qa_XXL9f>< zyB=`5r#BRN?)3zv3l3?perlE9Ap^-Dl-9}{%cYk=8Cf2mr?G0PCn)1BC79=Ffob_RX!-0AsVU(gfq$Z+;F)t7o* zL3hwEn{wwC)p=z2i?%oD%WW){L+0^`x7$CXywK&7Ey=D2Tupg1ymGKz0XfWpn(}hl zg3~8s$KfokF7bJSF1JHY0pE<0VmXoI7|5vdmlov92oPt#&+RYFEtH*>!!C!*Q`uDP z4LDtXY1$pEFD;YtBm-VpgB3NkVk;V6x5HDKUoD$aXOcJIFDcHK6E5iX%76aClBs?f zc@D2Sw=0V)T~4RpEhoF%lV6(a45$I_mkrd{7Wo1)2;?jYI0_qOyvY^?#p^AuX!84< za_D^lm!mqb-sAVmj`{p<@6@JJbyBz-j-WR%qIr<1S*n{tcg z^l&+4Ca-?12U!UJ7H39l$Op_BvNvpo)R(e$|tS)bEeVNA=vTPD{s8NNP`CorR7s+?@npkL;T@``1+$~f}+ zy^exua= z=9S0=#3SdwjEbtte4o@(4FufA#%69=D^s zp;W+Rf8^2^EGv|uEPLdao%2qgRwa8QBhx7rXH?|NHB*L;j2TB>ev!I#2V?|0ot1eN zGAQK+Atz+8v8YBmAeV!nD-dj~t&;14^g(*-&8;k#gCI8(u?jYn*ZXC($%vD4yP>R5 zE>m(kIRY+E{nSRey!w3}IgI|if+m-ouhOE&AE+yDkh4tARJWY?B{Q1j!s(OY>JEBx zs|#g#%K%kcbq!Nxyvg>RvgQ1ea);9?$45?gxj>c&{p#kX26SaztzXVD*|2mscY1@I zMsi}ync*uaY>)+VACNtD<<%4jtxNjpaJuWuY6C%a}r@EwCpyU!Scg|qB?2jCrfLyel&XUrypj@VWZgV&LL-l zKXOnTr!@qfBIA|2zavk0g*8FB7%31(u&kuOBPW;)MCopB{S3L^$yf}^ ziCo{*sBXXkbs=-smgdR%;gY+s!x1R1tCX9W+|lLo;4Nv&_p0Nj?&H4traCo@WVkxr z-rSlBxmK#1kzX!CO{LPPv?b?_tD-4i?n82~ce^}+${90c&tx~{Ekm%ZtX6I#a)f13 zps27=jY_BNp4V4bF;xzlj5uk+UE0_r*Jh85ce(r)7BqTg;L13X4OCZD$(tKF>15A+ z)s3ZsB$o{#aOdZjc;z6-TL}THFK?1-U_h>FGH6R1t7Q1A`?J>)+ zOs|rQs$4nMu&Sudle@WGPULd#Y?ziW*Alr7$@R(G*i6y!D~o%TGb5iG4PkX7V#k|W`oQQsJpw*W#Vhr7BmSBA2jJn~M@ zIjvgmb8d&6>hkuiy1G!>5j44`_>Yj!KX*!*FGb!m+vFpH?Zv!Hes`;PMBU}BqsblRKS=GgO^ zXDyrB-aaRPq2yZI<`uUuYO&{bbS!LbZkydw+u6LdWufgHS?6T=jI~OpW2MMP23zh$ zp$iUvd*sR?)w6e7K779moi7y!E}Z!9GUp4<#e1AR_>tIcMV7 zip!nnyDoEAPFeQ2{oDs8)jH4J<%wNqpMCM0*_T%oue|!~Io`{Ml>DT)F}vx8i9?Ue zyY#U0*DXHKRnxGny!5R#jXPa@*ScNKIk@HW!;aXm@m%%UGx7?JmPttMVq|C4oa)~m zQudazzm$WeIHY`2%2X+3QtG9s&#cEwnJwicDNCiO{mgCcw9jsDo739W+TPaEIkk1M zVzF1Z&#CF^Y+tyrrE|K%t8HJF*LqS*dF#TH?3L}O%xUQ?U$U^PxvjeSl;(vUt+IY~ z%iPY^c`cpIUG1F(t<4=Ror`OyNMoH{r7g{Ei>G!jY3-_MYo2{l%}J80ZCTi|v~{s; zv$mzJOZuT+W%4;xy>?XtL|*DRs#pEf_Z@S>HBY|u*F%5U{LPmSJbCnr-EZ$2SNqD^ zMF)O*+)MwguXwa$;dT4nRy#Iy=%o{nIsCupbiMNG#-mTIsXwuL(mgA#eCw^}&UyTp z-Af;PeCBuOAN$(XUHc!nX4!?4PJjKjx!2V?=RJPF<9p5VyglQmrxF`O;w$?(UcUC$ zn|3M5eRJG~sdN5(_F>KgR=nu_&rYjypBn1_NkQXt?9ByB0td~TI^p@{)1G~~>7GlUJNDh<#@KJWcAaFaO?@0aAN$A2&z{iIbY$(~{f_!h z-O)8C-;jOt-|wCM$@eCWnf$+-CXJi$jW_ceetpHIFF*K(_xju4JIA}^oxmHbkGrv< z(RSf?p7^Tawyc`)+s?r^&$;){|C+O6)*GvvA3oupt3SQizr0}c)8jVHyKK)Z zZ>q{asN(3oE_?X7llN-dIcJLNp-axFId0bZPh@A6sbMT1qMK!iJJgSk6Q6H+x8sD- zcXUP;^zfr{g?f$?xfZDNVx<&y-BIVT8dm48ELeYYSxNChWA8b6>;WxJ6IUEC_M967 z2dykP_>HTl%v$-qb!*Q5{^U!pyT4>FZ|j8bpEJIw^`^Vas#g}2tUPDh%Bl@(${NSH zX8JBapfK7+*^pd%qSu`fGTlXr+HOr)|_>cXjY6$P@$_ObdQ zuXZWYSKFLT9UE;~n?`P0zjkBRvpFN@bv)BDuO+8Fd(PTTZKK-OWUn8-Df^ib1J&YibmQ}&4UHfdr_%eu8KIXN3Q*hbG=w`Sh6>qgr;hIeGI8@cA$oH^?^ zj-Hb>H)q4Twd+TBjNXu)wRS}N#yQXAj9k06ZSwvHCt@SOE(chbB@#gE5DvUFwax~iV5MXswOtDoIz@-0ukcJh0VU47Gp#kXBi zal!z2lo4I26j4SUxaIe)9%9fsZ zaOKf=HTb3zq{_g3x0jbh!-zf^Zx5U+ke5WBi9|Y(?2dd@4Wjy zte!D@>3PTa);}=+koTL$K5^~p@9kv^NBbMOq(&~Tk%6O*iF)?8n-ta0eWXm1qOPmn zSP9BBB&ASFsg&tb>ZBYerA3OmD=Co0QkF?MQ_49~ltFdt?D@@|bIMw)ME4ysTQ1P4i`GeM@vI>FVYMEsN)uHZN*fTysiuS8HWw>$2uZ zrLxx2_W5l!&7IBjBlG3z&eq(vsHIi@FK_N#mfN{tNn2%eS7+<&*5=Cg_64%3isoF| zkaWj3QXL3+b?-lacBMQdn-ZP1**^K@;f*62M*Mu-G1>n&<(@aQ!{@$z-+Q)i4SwPW zkA)xKJm;CgC;tA}^UfSJ>%n#Hle+hO_pjdFSFOM5Z&zl|-uH(e`@$dlhyH5tvy&cj zPR;u7tgF`S>8NZdxjkp(dB=XRN9g=_a>r+{y7ZiH{WJTak<6rici$Ar^o%qPfPlW$^-~KCxey`!P_uSb9^Y5w{yXy0^E}cC5g)<-7 zdv4bFLAxzGq;j3@-)n~W*0$Y#TX^ZAb)M1PgI_=NI@{f@+rurkgTH^plwm9Xu;96e zv#z;d-1zLuU;N`QYld7o`rQ9L7(TJ!&L?uhXaD(`k8GWTI(?1ZPh7U|ict?=bkym0 zEcd+V@7UX(b@nkYW;LAAY}+^V!vpSKH`eyn$HyL-Gv%e<{d0zGqWzY=hTc2;yc)Z$ zu5QBY9EbnPxgQLgd%!U#ykfif_s8y^)xEZ{X?50_kIvm}59jBX{%Ckt!%1%xEPo+) z^Ywdep0xPIowIj8vtja1;ZG_adHNft9{Z8^-YiGS`IF~{4zuU}d{oEo?`|j$H=bQJ zdH2gEeCj+u`wweoJap)qPJP=ped9F?PUwE8{QJX3jM=#2*`>B&znQoB zz|EHqnmaA)*}JFxx48R`O@|)8(*x`7x$=;l`+pLeH{`^Fessq0VYj_^(Eb0g1#i9Q zl0(LR|DW*DgKzloslP4j-sjFs2Jba&x7{jRZN)!%e#*$#Z}hCR4eR>h--c9VE#7P5 zuSPnKc>1nYgDc-<<~A?%U*e%x140zv00lV~+dRrYCKo-A_4keE6*o z3oo=i5STmS5ZlsUT+}dT)v~{hd?4%b(fP00LzBjjzH0bL`(yV%n6>$#jn|I14gP4} z$F{(j|M}kxy*ujSyZ0*2dhA!f&t4e(?3$A9h4-zm95n8^t1i8AdDW{2zp&>E6Q60> zGy79p$3-{W7F?CPV%N!+oNm85>&}Y3uH7Yc?I*{rwH@Yoa_#i2tS1+&9d5hi!~1@e zl~;D!oxif(_SCg|4SwbEJKxon3q2>y!P13$u@W z^{V-~wn3G?!Q)Te_q3*e+a5jQN2|XTetE{^zmBk<`5(s~%MbY8UG1H2>YXN++&1~#<(uE`yy$=z9-n#5#_Wx& zHdfygzWX=V>^3Sq^UiAywcY1_YNBt|CGTH!|KLI4U;Xhj+auQ;KE-3daP65t&MN)K zTT3=s*BY6I&iT*1u}hI{>^mEFk1d*T_4Jvx*YED~53_ZgaLl4C+c)1)SLObCU0tkN z`dnUe-Iv>W^t&E4m0Oif9?9F<)Qr5{lM+58avMyd*dhdu_?A5hBdc#n+lE|&Wy8f#hLhKe4x%a4hN93*&zjwrLD(W5) zzm2GSQzS1NiQTy(vvPw{-(_`4u}i;{m%=F7zvMOF<%?#@cO?;1zvR^W{GUj9Udm9Z zuhyG6Dbn*vhuJ4ZcSJ>zJ5(g4?sJol3$>iuyr^Sg%i*)z7tNfcc1SIZ?w*R&&PJx{ zHWyhhauYRIR6DMgS`g8ia?>;`Ml{q?MWdz>Cl$op%oXBFs*Wm(PEBKy@lf5F&6O?2 z=z5X&1o8TU9ZlC53_@u8RLWd2l&@kUpOw`66g9?NQmoH^l@d2enIlDwV|B$;<4BF6 zE2XTGa;ubkr93KSjg<9LekbJ}DVwFJ?a1R&c{nUjt>k%%JSCGS*iNrJ3X(5g52DtQ*79ww+qwz5c`+Q{Qsd0r(?&*j?<`4&N*ZA)EwMktRpeDXL!J$aW0cIttZ zJT!2~lSP+25R!+&@?=y!!ihYYmdD2OtWBOa%VQ~dt|^ZbM6M^1b7-Vw)U`wPLnR{psJrFX&p*$J^g~@c z)%mCDmB>f(=DM1E`83}=Pv)PHGEs`^hgx^Pv958|OV4e*^_kaxC-wJ~ROGrfMy6_J zKPhS~xuqN_B~MC;6ty9Both;@wOi1#aDHo^yisjl*xuPvyiA@9G|y_DlRv+ut*yDW zxP8guuJ*QymQ(DF?VTq{Z)P{oYcFeGY|m|*CyxLYSI%#5X=^>TMyj7CrD%z?blUv( zC2~SVrZw{=nBT14P0M%Sk)JkO`0h_fEYaVmM)H?^7CiK7+sVHgU$fhZkAAfD{vj&| z{o=?!gr9R?`ugY>=3O~=ysfxv)BaO!p`EKf7`w;wu3fITol^ANTStce@5EmY8Qw7U z=GkrG7j5AW_U~9Rcu~%f+Mn)r)h64YUOi^r;8Uk&-F$8KC#@}io@_hB^M37!RVUr| z)}FRU1Fv-q>aOyhdt~-^%U0fdMtAu;Kj~EdJ-<@#H$B&9DFO%}fB*srAb<9uE4mte!^Se&_=jXAP)#t2OkR6-uYAs-j00IagfB*sr zAb T: """ parse model out text to prompt define response @@ -129,8 +143,8 @@ class BaseOutputParser(ABC): """ cleaned_output = model_out_text.rstrip() - # if "```json" in cleaned_output: - # _, cleaned_output = cleaned_output.split("```json") + if "```json" in cleaned_output: + _, cleaned_output = cleaned_output.split("```json") # if "```" in cleaned_output: # cleaned_output, _ = cleaned_output.split("```") if cleaned_output.startswith("```json"): @@ -142,18 +156,12 @@ class BaseOutputParser(ABC): cleaned_output = cleaned_output.strip() if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"): logger.info("illegal json processing") - json_pattern = r"{(.+?)}" - m = re.search(json_pattern, cleaned_output) - if m: - cleaned_output = m.group(0) - else: - raise ValueError("model server out not fllow the prompt!") + cleaned_output = self.__extract_json(cleaned_output) cleaned_output = ( cleaned_output.strip() - .replace("\n", "") - .replace("\\n", "") - .replace("\\", "") - .replace("\\", "") + .replace("\n", " ") + .replace("\\n", " ") + .replace("\\", " ") ) return cleaned_output diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index e25a17340..f91af967c 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -78,11 +78,8 @@ class ChatWithPlugin(BaseChat): super().chat_show() def __list_to_prompt_str(self, list: List) -> str: - if list: - separator = "\n" - return separator.join(list) - else: - return "" + return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list)) + def generate(self, p) -> str: return super().generate(p) diff --git a/pilot/scene/chat_execution/out_parser.py b/pilot/scene/chat_execution/out_parser.py index 7b7abbc09..6f67bde38 100644 --- a/pilot/scene/chat_execution/out_parser.py +++ b/pilot/scene/chat_execution/out_parser.py @@ -14,20 +14,24 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") class PluginAction(NamedTuple): command: Dict speak: str - reasoning: str thoughts: str class PluginChatOutputParser(BaseOutputParser): def parse_prompt_response(self, model_out_text) -> T: - response = json.loads(super().parse_prompt_response(model_out_text)) - command, thoughts, speak, reasoning = ( + clean_json_str = super().parse_prompt_response(model_out_text) + print(clean_json_str) + try: + response = json.loads(clean_json_str) + except Exception as e: + raise ValueError("model server out not fllow the prompt!") + + command, thoughts, speak = ( response["command"], response["thoughts"], - response["speak"], - response["reasoning"], + response["speak"] ) - return PluginAction(command, speak, reasoning, thoughts) + return PluginAction(command, speak, thoughts) def parse_view_response(self, speak, data) -> str: ### tool out data to table view diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index 98ba1652c..eebb8de94 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -10,7 +10,7 @@ from pilot.scene.chat_execution.out_parser import PluginChatOutputParser CFG = Config() -PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.Play to your strengths as an LLM and pursue simple strategies with no legal complications.""" +PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.""" PROMPT_SUFFIX = """ Goals: @@ -20,25 +20,22 @@ Goals: _DEFAULT_TEMPLATE = """ Constraints: - Exclusively use the commands listed in double quotes e.g. "command name" - Reflect on past decisions and strategies to refine your approach. - Constructively self-criticize your big-picture behavior constantly. - {constraints} +0.Exclusively use the commands listed in double quotes e.g. "command name" +{constraints} Commands: - {commands_infos} +{commands_infos} """ -PROMPT_RESPONSE = """You must respond in JSON format as following format: -{response} - +PROMPT_RESPONSE = """ +Please response strictly according to the following json format: + {response} Ensure the response is correct json and can be parsed by Python json.loads """ RESPONSE_FORMAT = { "thoughts": "thought text", - "reasoning": "reasoning", "speak": "thoughts summary to say to user", "command": {"name": "command name", "args": {"arg name": "value"}}, } diff --git a/pilot/server/__init__.py b/pilot/server/__init__.py index 909f8bf4b..55f525988 100644 --- a/pilot/server/__init__.py +++ b/pilot/server/__init__.py @@ -4,6 +4,8 @@ import sys from dotenv import load_dotenv + + if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"): print("Setting random seed to 42") random.seed(42) diff --git a/pilot/server/bar_chart.html b/pilot/server/bar_chart.html deleted file mode 100644 index 8c4d0c714..000000000 --- a/pilot/server/bar_chart.html +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/pilot/server/db-gpt-test.db b/pilot/server/db-gpt-test.db deleted file mode 100644 index 929805035e1c69ca980b5e0750d0d4b7cc3c7f3e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeI#p$&jA5Cu?3V1#;hU<7bD5+-0_W+1i903=|FLtgUE{JYfjURCRDU-L1iaT%t* zdaAjjdwW5E009C72oNAZfB*pk1PH_zXj8ev`Kj{MM1TMR0t5&UAV7cs0RjXFL=^D< ZkN9ftOn?9Z0t5&UAV7cs0Rja630#2eCE)-7 diff --git a/tests/unit/test_plugins.py b/tests/unit/test_plugins.py index a2a3d2506..b82a09bc4 100644 --- a/tests/unit/test_plugins.py +++ b/tests/unit/test_plugins.py @@ -3,6 +3,7 @@ import os import pytest from pilot.configs.config import Config +from pilot.configs.model_config import PLUGINS_DIR from pilot.plugins import ( denylist_allowlist_check, inspect_zip_for_modules,