add vicuna embedding

Author: csunny
Date: 2023-04-29 18:28:42 +08:00
Parent: 4b2e3bf59a
Commit: 0767537606
8 changed files with 458 additions and 251 deletions


@@ -2,8 +2,6 @@
# -*- coding: utf-8 -*-
import torch
from pilot.utils import get_gpu_memory
from fastchat.serve.inference import compress_module
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
@@ -28,12 +26,12 @@ class ModerLoader:
        tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(self.model_path, low_cpu_mem_usage=True, **self.kwargs)
        if load_8bit:
            compress_module(model, self.device)
        if debug:
            print(model)
        if self.device == "cuda":
            model.to(self.device)
        return model, tokenizer
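
The hunk above only shows the body of ModerLoader's loading method, so the sketch below is a hypothetical usage example, not part of this commit. The module path pilot.model.loader, the constructor signature (model_path, device), and the method name loader(load_8bit, debug) are assumptions inferred from the attributes referenced in the diff (self.model_path, self.device, self.kwargs).

    # Hypothetical usage sketch for the loader shown in the diff above.
    # Names and signatures marked below are assumptions, not confirmed by this commit.
    import torch

    from pilot.model.loader import ModerLoader  # assumed module path

    if __name__ == "__main__":
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Constructor arguments are assumed from self.model_path / self.device in the hunk.
        loader = ModerLoader(model_path="/path/to/vicuna-13b", device=device)

        # load_8bit=True would take the compress_module() branch shown in the hunk,
        # trading some accuracy for a smaller GPU memory footprint; debug=True prints the model.
        model, tokenizer = loader.loader(load_8bit=False, debug=True)

        inputs = tokenizer("Hello, Vicuna!", return_tensors="pt").to(device)
        outputs = model.generate(**inputs, max_new_tokens=32)
        print(tokenizer.decode(outputs[0], skip_special_tokens=True))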