From acf9dbbd8218a803a2d6bd81f4827ebff07a2414 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sat, 29 Apr 2023 21:50:47 +0800
Subject: [PATCH] fix problem

---
 pilot/model/loader.py   | 4 ++--
 pilot/server/chatbot.py | 2 +-
 pilot/server/sqlgpt.py  | 3 +--
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index 979b2bb89..7b78ebe8c 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -29,8 +29,8 @@ class ModerLoader:
         if debug:
             print(model)
 
-        if self.device == "cuda":
-            model.to(self.device)
+        # if self.device == "cuda":
+        #     model.to(self.device)
 
         return model, tokenizer
 
diff --git a/pilot/server/chatbot.py b/pilot/server/chatbot.py
index 5796a8c66..5e0ad9294 100644
--- a/pilot/server/chatbot.py
+++ b/pilot/server/chatbot.py
@@ -6,7 +6,7 @@ import json
 import time
 from urllib.parse import urljoin
 import gradio as gr
-from configs.model_config import *
+from pilot.configs.model_config import *
 vicuna_base_uri = "http://192.168.31.114:21002/"
 vicuna_stream_path = "worker_generate_stream"
 vicuna_status_path = "worker_get_status"
diff --git a/pilot/server/sqlgpt.py b/pilot/server/sqlgpt.py
index 6dbf1bfc1..773de8611 100644
--- a/pilot/server/sqlgpt.py
+++ b/pilot/server/sqlgpt.py
@@ -5,7 +5,7 @@ import json
 import torch
 import gradio as gr
 
-from fastchat.serve.inference import generate_stream, compress_module
+from fastchat.serve.inference import generate_stream
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -20,7 +20,6 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 
 def generate(prompt):
-    compress_module(model, device)
     model.to(device)
     print(model, tokenizer)
     params = {