mirror of https://github.com/csunny/DB-GPT.git
add vicuna embedding
@@ -2,8 +2,6 @@
# -*- coding: utf-8 -*-

import torch
from pilot.utils import get_gpu_memory
from fastchat.serve.inference import compress_module
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
@@ -28,12 +26,12 @@ class ModerLoader:
        tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(self.model_path, low_cpu_mem_usage=True, **self.kwargs)

        if load_8bit:
            compress_module(model, self.device)

        if debug:
            print(model)

        if self.device == "cuda":
            model.to(self.device)

        return model, tokenizer
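For readers who want to try the loading pattern outside the diff context, below is a minimal, self-contained sketch of what the hunks above show: build a tokenizer and causal LM from a local path, optionally apply FastChat's 8-bit compression, and move the model to CUDA. The class name ModelLoader and its constructor signature are assumptions inferred from the diff, not necessarily the repository's exact code; only the calls visible in the hunks are taken as given.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class ModelLoader:
    """Sketch of the loader pattern in the diff; names are illustrative."""

    def __init__(self, model_path: str, device: str = "cuda", **kwargs):
        self.model_path = model_path
        self.device = device
        self.kwargs = kwargs

    def loader(self, load_8bit: bool = False, debug: bool = False):
        # Slow tokenizer is used explicitly, as in the diff.
        tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_path, low_cpu_mem_usage=True, **self.kwargs
        )

        if load_8bit:
            # The diff uses FastChat's compress_module(model, device) for
            # 8-bit weight compression; the import is guarded here because
            # it is only available in older FastChat releases.
            from fastchat.serve.inference import compress_module
            compress_module(model, self.device)

        if debug:
            # Print the module tree for inspection.
            print(model)

        if self.device == "cuda":
            model.to(self.device)

        return model, tokenizer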
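The commit title mentions a Vicuna embedding, but the embedding code itself is not part of the hunks shown above. As a purely hypothetical illustration of how an embedding vector can be derived from a Vicuna-style causal LM returned by the loader, one common approach is to mean-pool the final hidden states; the function name and pooling choice below are assumptions, not taken from the repository.

import torch


@torch.no_grad()
def embed(text: str, model, tokenizer, device: str = "cuda") -> torch.Tensor:
    # Hypothetical sketch: tokenize, run a forward pass with hidden states
    # enabled, and mean-pool the last layer over the sequence dimension.
    inputs = tokenizer(text, return_tensors="pt").to(device)
    outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]      # (1, seq_len, hidden_size)
    return last_hidden.mean(dim=1).squeeze(0)    # (hidden_size,)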