fix problem

This commit is contained in:
csunny 2023-04-29 21:50:47 +08:00
parent 0767537606
commit acf9dbbd82
3 changed files with 4 additions and 5 deletions

View File

@ -29,8 +29,8 @@ class ModerLoader:
if debug:
print(model)
if self.device == "cuda":
model.to(self.device)
# if self.device == "cuda":
# model.to(self.device)
return model, tokenizer

View File

@ -6,7 +6,7 @@ import json
import time
from urllib.parse import urljoin
import gradio as gr
from configs.model_config import *
from pilot.configs.model_config import *
vicuna_base_uri = "http://192.168.31.114:21002/"
vicuna_stream_path = "worker_generate_stream"
vicuna_status_path = "worker_get_status"

View File

@ -5,7 +5,7 @@
import json
import torch
import gradio as gr
from fastchat.serve.inference import generate_stream, compress_module
from fastchat.serve.inference import generate_stream
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
@ -20,7 +20,6 @@ model = AutoModelForCausalLM.from_pretrained(
)
def generate(prompt):
compress_module(model, device)
model.to(device)
print(model, tokenizer)
params = {