diff --git a/.env.template b/.env.template
index 2ed7932f3..b06ef215e 100644
--- a/.env.template
+++ b/.env.template
@@ -101,8 +101,8 @@ LANGUAGE=en
#*******************************************************************#
# ** PROXY_SERVER
#*******************************************************************#
-PROXY_API_KEY=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
-PROXY_SERVER_URL=http://127.0.0.1:3000/api/openai/v1/chat/completions
+PROXY_API_KEY=
+PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address
#*******************************************************************#
diff --git a/README.zh.md b/README.zh.md
index 3db20702c..06260c9b1 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -260,7 +260,7 @@ Run the Python interpreter and type the commands:
This is a complex and innovative tool for databases, and the project is under active development, with new features released continuously. If you run into any specific problems while using it, please open an issue on the project first; if needed, contact me via the WeChat account below and I will do my best to help. Contributions to the project are also very welcome.
-
+
## Licence
diff --git a/assets/DB_GPT_wechat.png b/assets/DB_GPT_wechat.png
deleted file mode 100644
index a1d7f7558..000000000
Binary files a/assets/DB_GPT_wechat.png and /dev/null differ
diff --git a/assets/wechat.jpg b/assets/wechat.jpg
index 77ddbc89a..0a0c3c003 100644
Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 33d96f1c2..4b8b85a62 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
+ "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
"proxyllm": "proxyllm",
}
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 0ff368c70..64d3617bf 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -3,7 +3,7 @@
from functools import cache
from typing import List
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from pilot.configs.model_config import DEVICE
@@ -86,7 +86,15 @@ class ChatGLMAdapater(BaseLLMAdaper):
class GuanacoAdapter(BaseLLMAdaper):
"""TODO Support guanaco"""
- pass
+ def match(self, model_path: str):
+ return "guanaco" in model_path
+
+ def loader(self, model_path: str, from_pretrained_kwargs: dict):
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
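+        # Load the merged Guanaco weights quantized to 4-bit (bitsandbytes) on GPU 0 to reduce memory use.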
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs
+ )
+ return model, tokenizer
class CodeGenAdapter(BaseLLMAdaper):
@@ -143,6 +151,7 @@ class ProxyllmAdapter(BaseLLMAdaper):
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
+register_llm_model_adapters(GuanacoAdapter)
# TODO Default support vicuna, other model need to tests and Evaluate
# just for test, remove this later
diff --git a/pilot/model/guanaco_llm.py b/pilot/model/guanaco_llm.py
new file mode 100644
index 000000000..03f2d1687
--- /dev/null
+++ b/pilot/model/guanaco_llm.py
@@ -0,0 +1,65 @@
+import torch
+from threading import Thread
+from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
+from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
+
+def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
+ """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+
+ print(params)
+ stop = params.get("stop", "###")
+ messages = params["prompt"]
+
+
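+    # Collect prior (user, assistant) turns from the prompt; kept for reference, not used below.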
+ hist = []
+ for i in range(1, len(messages) - 2, 2):
+ hist.append(
+ (
+ messages[i].split(ROLE_USER + ":")[1],
+ messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
+ )
+ )
+
+ query = messages[-2].split(ROLE_USER + ":")[1]
+ print("Query Message: ", query)
+
+ input_ids = tokenizer(query, return_tensors="pt").input_ids
+ input_ids = input_ids.to(model.device)
+
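+    # TextIteratorStreamer yields decoded text incrementally while generate() runs in a background thread.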
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+ stop_token_ids = [0]
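+    # Extra stopping criterion: halt generation as soon as token id 0 is emitted.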
+ class StopOnTokens(StoppingCriteria):
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ for stop_id in stop_token_ids:
+ if input_ids[0][-1] == stop_id:
+ return True
+ return False
+
+    stop_criteria = StopOnTokens()
+
+ generate_kwargs = dict(
+ input_ids=input_ids,
+ max_new_tokens=512,
+ temperature=1.0,
+ do_sample=True,
+ top_k=1,
+ streamer=streamer,
+ repetition_penalty=1.7,
+        stopping_criteria=StoppingCriteriaList([stop_criteria])
+ )
+
+
+    # Run generation on a background thread; decoded text arrives through the streamer.
+    t1 = Thread(target=model.generate, kwargs=generate_kwargs)
+    t1.start()
+
+    # Yield the accumulated response so far for each newly streamed chunk.
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield out.split("### Response:")[-1].strip()
+
diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py
index 118d45f97..359d478f8 100644
--- a/pilot/model/llm_utils.py
+++ b/pilot/model/llm_utils.py
@@ -1,6 +1,11 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
+import traceback
+from queue import Queue
+from threading import Thread
+import transformers
+
from typing import List, Optional
from pilot.configs.config import Config
@@ -47,3 +52,66 @@ def create_chat_completion(
response = None
# TODO impl this use vicuna server api
+
+
+class Stream(transformers.StoppingCriteria):
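+    """Stopping criterion that never halts generation; it only forwards each step's token ids to a callback."""
+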
+ def __init__(self, callback_func=None):
+ self.callback_func = callback_func
+
+ def __call__(self, input_ids, scores) -> bool:
+ if self.callback_func is not None:
+ self.callback_func(input_ids[0])
+ return False
+
+
+class Iteratorize:
+    """Transforms a function that takes a callback into a lazy iterator (generator)."""
+
+    def __init__(self, func, kwargs=None, callback=None):
+ self.mfunc = func
+ self.c_callback = callback
+ self.q = Queue()
+ self.sentinel = object()
+        self.kwargs = kwargs or {}
+ self.stop_now = False
+
+ def _callback(val):
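+            # Push each intermediate value onto the queue; raise to abort generation once stop_now is set.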
+ if self.stop_now:
+ raise ValueError
+ self.q.put(val)
+
+        def gentask():
+            ret = None
+            try:
+                ret = self.mfunc(callback=_callback, **self.kwargs)
+            except ValueError:
+                # Raised by _callback when stop_now is set; treat as a normal early stop.
+                pass
+            except Exception:
+                traceback.print_exc()
+
+ self.q.put(self.sentinel)
+ if self.c_callback:
+ self.c_callback(ret)
+
+ self.thread = Thread(target=gentask)
+ self.thread.start()
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ obj = self.q.get(True, None)
+ if obj is self.sentinel:
+ raise StopIteration
+ else:
+ return obj
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.stop_now = True
diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index a3d443da8..9fe6207c1 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -113,7 +113,11 @@ class ModelLoader(metaclass=Singleton):
or self.device == "mps"
and tokenizer
):
- model.to(self.device)
+            # 4-bit quantized models cannot be moved with .to(device); ignore the resulting ValueError
+ try:
+ model.to(self.device)
+ except ValueError:
+ pass
if debug:
print(model)
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index eb992c8eb..7a61c788f 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -108,7 +108,7 @@ class BaseOutputParser(ABC):
if not self.is_stream_out:
return self._parse_model_nostream_resp(response, self.sep)
else:
- return self._parse_model_stream_resp(response, self.sep, skip_echo_len)
+ return self._parse_model_stream_resp(response, self.sep)
def parse_prompt_response(self, model_out_text) -> T:
"""
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 17d2f95a8..39737112b 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -91,8 +91,9 @@ class GuanacoChatAdapter(BaseChatAdpter):
return "guanaco" in model_path
def get_generate_stream_func(self):
- # TODO
- pass
+ from pilot.model.guanaco_llm import guanaco_generate_output
+
+ return guanaco_generate_output
class ProxyllmChatAdapter(BaseChatAdpter):
@@ -107,6 +108,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
+register_llm_model_chat_adapter(GuanacoChatAdapter)
# Proxy model for test and develop, it's cheap for us now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)