Merge remote-tracking branch 'origin/main' into dev

aries-ckt 2023-06-05 10:00:48 +08:00
commit 31d457cfd5
8 changed files with 64 additions and 12 deletions


@@ -41,7 +41,7 @@ MAX_POSITION_EMBEDDINGS=4096
#** DATABASE SETTINGS **#
#*******************************************************************#
LOCAL_DB_USER=root
-LOCAL_DB_PASSWORD=aa12345678
+LOCAL_DB_PASSWORD=aa123456
LOCAL_DB_HOST=127.0.0.1
LOCAL_DB_PORT=3306
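As a quick sanity check, these settings can be exercised from Python. The sketch below assumes the `pymysql` client package; the host, port, user, and password come from the values above:

```python
import pymysql

# Connect with the LOCAL_DB_* settings above (pymysql is an assumption;
# any MySQL client library would do).
conn = pymysql.connect(
    host="127.0.0.1",
    port=3306,
    user="root",
    password="aa123456",
)
with conn.cursor() as cur:
    cur.execute("SELECT VERSION()")
    print(cur.fetchone())  # e.g. ('8.0.33',)
conn.close()
```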


@@ -123,7 +123,7 @@ As our project has the ability to achieve ChatGPT performance of over 85%, there
This project relies on a local MySQL database service, which you need to install locally. We recommend using Docker for installation.
```bash
-$ docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa12345678 -dit mysql:latest
+$ docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa123456 -dit mysql:latest
```
We use [Chroma embedding database](https://github.com/chroma-core/chroma) as the default for our vector database, so there is no need for special installation. If you choose to connect to other databases, you can follow our tutorial for installation and configuration.
For the entire installation process of DB-GPT, we use the miniconda3 virtual environment. Create a virtual environment and install the Python dependencies.
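Because Chroma runs in-process, the default vector store needs no separate service. A minimal sketch of using it directly, assuming the `chromadb` Python package (the collection name and document contents are illustrative):

```python
import chromadb

# Chroma runs in-process by default; there is no server to start.
client = chromadb.Client()
collection = client.create_collection("dbgpt-docs")  # hypothetical name
collection.add(
    documents=["DB-GPT stores relational data in a local MySQL instance."],
    ids=["doc-1"],
)
result = collection.query(
    query_texts=["Which database does DB-GPT use?"], n_results=1
)
print(result["documents"])
```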


@@ -113,7 +113,7 @@ TODO: For the front-end display we will provide product interfaces for multiple platforms, including PC and mobile
This project relies on a local MySQL database service, which you need to install locally; we recommend installing it directly with Docker.
```
-docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa12345678 -dit mysql:latest
+docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa123456 -dit mysql:latest
```
By default we use the Chroma in-memory database as our vector database, so no special installation is needed; if you need to connect to another database, you can follow our tutorial to install and configure it. The entire DB-GPT installation uses a miniconda3 virtual environment: create the virtual environment and install the Python dependencies.


@@ -1,5 +1,4 @@
import torch
import copy
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
@@ -57,3 +56,55 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
    out = decoded_output.split("### Response:")[-1].strip()
    yield out


def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
    tokenizer.bos_token_id = 1
    print(params)
    stop = params.get("stop", "###")  # currently unused; stopping is token-id based below
    prompt = params["prompt"]
    max_new_tokens = params.get("max_new_tokens", 512)
    temperature = params.get("temperature", 1.0)

    query = prompt
    print("Query Message: ", query)

    input_ids = tokenizer(query, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    stop_token_ids = [0]

    class StopOnTokens(StoppingCriteria):
        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            # Stop as soon as the last generated token is one of the stop ids.
            for stop_id in stop_token_ids:
                if input_ids[-1][-1] == stop_id:
                    return True
            return False

    stop_criteria = StopOnTokens()

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        top_k=1,
        streamer=streamer,
        repetition_penalty=1.7,
        stopping_criteria=StoppingCriteriaList([stop_criteria]),
    )

    # Run generation in a background thread so the streamer can be consumed
    # while tokens are still being produced; a direct call would only
    # return output after generation finishes.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    out = ""
    for new_text in streamer:
        out += new_text
        yield out
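Below is a minimal sketch of driving this streaming generator end to end; the checkpoint name, prompt format, and sampling values are illustrative assumptions rather than settings taken from this commit:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream

# Assumed checkpoint and prompt format, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("timdettmers/guanaco-33b-merged")
model = AutoModelForCausalLM.from_pretrained(
    "timdettmers/guanaco-33b-merged", device_map="auto"
)

params = {
    "prompt": "### Human: What is DB-GPT?\n### Assistant:",
    "temperature": 0.7,
    "max_new_tokens": 256,
}

# Each yielded value is the full text generated so far, so the last one
# is the complete response.
for partial in guanaco_generate_stream(model, tokenizer, params, device="cuda"):
    print(partial)
```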


@@ -68,15 +68,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
"max_tokens": params.get("max_new_tokens"),
}
print(payloads)
print(headers)
res = requests.post(
CFG.proxy_server_url, headers=headers, json=payloads, stream=True
)
text = ""
print("====================================res================")
print(res)
for line in res.iter_lines():
if line:
decoded_line = line.decode("utf-8")
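The same line-by-line streaming pattern can be tried in isolation; the sketch below substitutes a public test endpoint for `CFG.proxy_server_url`, purely for illustration:

```python
import requests

# httpbin's /stream endpoint emits a few JSON lines, standing in for the
# proxy server's chunked response.
res = requests.get("https://httpbin.org/stream/3", stream=True)
text = ""
for line in res.iter_lines():
    if line:
        decoded_line = line.decode("utf-8")
        text += decoded_line
        print(decoded_line)
```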


@@ -118,6 +118,8 @@ class ModelLoader(metaclass=Singleton):
            model.to(self.device)
        except ValueError:
            pass
+        except AttributeError:
+            pass

        if debug:
            print(model)


@@ -56,8 +56,11 @@ class BaseOutputParser(ABC):
            # output = data["text"][skip_echo_len + 11:].strip()
            output = data["text"][skip_echo_len:].strip()
        elif "guanaco" in CFG.LLM_MODEL:
-            # output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
-            output = data["text"][skip_echo_len:].replace("<s>", "").strip()
+            # Non-stream output:
+            # output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
+            # Stream output:
+            output = data["text"][11:].replace("<s>", "").strip()
        else:
            output = data["text"].strip()


@@ -101,9 +101,9 @@ class GuanacoChatAdapter(BaseChatAdpter):
return "guanaco" in model_path
def get_generate_stream_func(self):
from pilot.model.llm_out.guanaco_llm import guanaco_generate_output
from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
return guanaco_generate_output
return guanaco_generate_stream
class ProxyllmChatAdapter(BaseChatAdpter):