Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-13 14:06:43 +00:00)
Merge remote-tracking branch 'origin/main' into dev
This commit is contained in: commit 31d457cfd5

@@ -41,7 +41,7 @@ MAX_POSITION_EMBEDDINGS=4096
#** DATABASE SETTINGS **#
#*******************************************************************#
LOCAL_DB_USER=root
-LOCAL_DB_PASSWORD=aa12345678
+LOCAL_DB_PASSWORD=aa123456
LOCAL_DB_HOST=127.0.0.1
LOCAL_DB_PORT=3306

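For reference, these `LOCAL_DB_*` settings are what the local database connection is built from. A minimal sketch of how they might be consumed — the use of pymysql and the environment-variable reading here are assumptions for illustration, not DB-GPT's actual loader code:

```python
# A minimal sketch, not DB-GPT's actual code: read the LOCAL_DB_* settings
# from the environment and open a MySQL connection with pymysql.
import os
import pymysql

conn = pymysql.connect(
    host=os.getenv("LOCAL_DB_HOST", "127.0.0.1"),
    port=int(os.getenv("LOCAL_DB_PORT", "3306")),
    user=os.getenv("LOCAL_DB_USER", "root"),
    password=os.getenv("LOCAL_DB_PASSWORD", "aa123456"),
)
with conn.cursor() as cursor:
    cursor.execute("SHOW DATABASES")
    print(cursor.fetchall())
```
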
@@ -123,7 +123,7 @@ As our project has the ability to achieve ChatGPT performance of over 85%, there
This project relies on a local MySQL database service, which you need to install locally. We recommend using Docker for installation.

```bash
-$ docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa12345678 -dit mysql:latest
+$ docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa123456 -dit mysql:latest
```
We use [Chroma embedding database](https://github.com/chroma-core/chroma) as the default for our vector database, so there is no need for special installation. If you choose to connect to other databases, you can follow our tutorial for installation and configuration.
For the entire installation process of DB-GPT, we use the miniconda3 virtual environment. Create a virtual environment and install the Python dependencies.

@@ -113,7 +113,7 @@ TODO: For terminal display we will provide multi-platform product interfaces, including PC and mobile

This project depends on a local MySQL database service, which you need to install locally; we recommend installing it directly with Docker.
```
-docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa12345678 -dit mysql:latest
+docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa123456 -dit mysql:latest
```
By default we use the Chroma in-memory database as the vector database, so no special installation is needed. If you need to connect to another database, you can follow our tutorial to install and configure it. For the entire DB-GPT installation process we use a miniconda3 virtual environment; create the virtual environment and install the Python dependencies.

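Both README hunks above describe Chroma as the default, zero-install vector store. A minimal sketch of what using Chroma directly looks like — this uses the chromadb package's client API purely as an illustration, not DB-GPT's own vector-store wrapper, and the collection and document names are placeholders:

```python
# A minimal sketch using the chromadb package directly; DB-GPT wraps this
# behind its own vector-store interface, so names here are illustrative.
import chromadb

client = chromadb.Client()  # in-memory by default, no separate server needed
collection = client.create_collection("dbgpt_demo")
collection.add(
    documents=["DB-GPT lets you chat with your database."],
    metadatas=[{"source": "readme"}],
    ids=["doc-1"],
)
print(collection.query(query_texts=["what does DB-GPT do?"], n_results=1))
```
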
@@ -1,5 +1,4 @@
import torch
import copy
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER

@@ -57,3 +56,55 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
    out = decoded_output.split("### Response:")[-1].strip()

    yield out
+
+
+def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    tokenizer.bos_token_id = 1
+    print(params)
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    max_new_tokens = params.get("max_new_tokens", 512)
+    temperature = params.get("temperature", 1.0)
+
+    query = prompt
+    print("Query Message: ", query)
+
+    input_ids = tokenizer(query, return_tensors="pt").input_ids
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
+
+    tokenizer.bos_token_id = 1
+    stop_token_ids = [0]
+
+    class StopOnTokens(StoppingCriteria):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
+            for stop_id in stop_token_ids:
+                if input_ids[-1][-1] == stop_id:
+                    return True
+            return False
+
+    stop = StopOnTokens()
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        do_sample=True,
+        top_k=1,
+        streamer=streamer,
+        repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stop]),
+    )
+
+    model.generate(**generate_kwargs)
+
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield out

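The new generator above yields the accumulated text after each decoded chunk. A minimal sketch of how a caller might drive it, assuming a Guanaco model and tokenizer are already loaded elsewhere (the prompt text and sampling values below are placeholders, not values taken from the diff):

```python
# A minimal consumption sketch; `model` and `tokenizer` are assumed to be a
# Guanaco checkpoint already loaded elsewhere (e.g. via DB-GPT's model loader).
from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream

params = {
    "prompt": "### Human: Which tables are in my database?\n### Assistant:",
    "temperature": 0.7,
    "max_new_tokens": 256,
}

# Each yielded value is the full text generated so far, so the last value
# seen when the stream ends is the complete response.
response = ""
for response in guanaco_generate_stream(model, tokenizer, params, device="cuda"):
    print(response, end="\r")
print()
print(response)
```
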
@@ -68,15 +68,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
        "max_tokens": params.get("max_new_tokens"),
    }

-    print(payloads)
-    print(headers)
    res = requests.post(
        CFG.proxy_server_url, headers=headers, json=payloads, stream=True
    )

    text = ""
-    print("====================================res================")
-    print(res)
    for line in res.iter_lines():
        if line:
            decoded_line = line.decode("utf-8")

@@ -118,6 +118,8 @@ class ModelLoader(metaclass=Singleton):
            model.to(self.device)
        except ValueError:
            pass
+        except AttributeError:
+            pass

        if debug:
            print(model)

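The added branch above makes the device move best-effort: some wrapped or quantized models either refuse `.to()` or do not expose it at all. A self-contained illustration of the pattern — the helper name and comments are illustrative, not the loader's exact surrounding code:

```python
# Illustration of the defensive device-move pattern from the loader hunk above;
# the helper name is illustrative, not DB-GPT's exact code.
def move_to_device_best_effort(model, device):
    try:
        model.to(device)
    except ValueError:
        # e.g. 8-bit/quantized models that reject .to()
        pass
    except AttributeError:
        # e.g. wrapper objects that do not expose .to() at all
        pass
    return model
```
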
@@ -56,8 +56,11 @@ class BaseOutputParser(ABC):
            # output = data["text"][skip_echo_len + 11:].strip()
            output = data["text"][skip_echo_len:].strip()
        elif "guanaco" in CFG.LLM_MODEL:
            # output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
-            output = data["text"][skip_echo_len:].replace("<s>", "").strip()
+            # NO stream output
+            # output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()

+            # stream out output
+            output = data["text"][11:].replace("<s>", "").strip()
        else:
            output = data["text"].strip()

@@ -101,9 +101,9 @@ class GuanacoChatAdapter(BaseChatAdpter):
        return "guanaco" in model_path

    def get_generate_stream_func(self):
-        from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
+        from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream

-        return guanaco_generate_output
+        return guanaco_generate_stream


class ProxyllmChatAdapter(BaseChatAdpter):

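This last hunk is what actually switches Guanaco serving over to the streaming generator: the server picks an adapter by model path and calls whatever function `get_generate_stream_func` returns. A rough sketch of that dispatch pattern — the class shapes are paraphrased from the diff, and the lookup helper is hypothetical:

```python
# A rough sketch of the adapter dispatch implied by the hunk above; the
# lookup helper here is hypothetical, not DB-GPT's exact API.
class BaseChatAdpter:
    def match(self, model_path: str) -> bool:
        return False

    def get_generate_stream_func(self):
        raise NotImplementedError


class GuanacoChatAdapter(BaseChatAdpter):
    def match(self, model_path: str) -> bool:
        return "guanaco" in model_path

    def get_generate_stream_func(self):
        from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream

        return guanaco_generate_stream


def get_chat_adapter(model_path: str) -> BaseChatAdpter:
    # Return the first registered adapter whose match() accepts the model path.
    for adapter in (GuanacoChatAdapter(),):
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"no chat adapter found for {model_path}")


stream_func = get_chat_adapter("/models/guanaco-33b").get_generate_stream_func()
```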