Merge branch 'dbgpt_doc' of https://github.com/csunny/DB-GPT into llm_fxp

This commit is contained in:
csunny 2023-06-14 10:25:07 +08:00
commit 14dc2e4ce9
2 changed files with 4 additions and 6 deletions

View File

@@ -108,7 +108,7 @@ class GuanacoAdapter(BaseLLMAdaper):
def loader(self, model_path: str, from_pretrained_kwargs: dict): def loader(self, model_path: str, from_pretrained_kwargs: dict):
tokenizer = LlamaTokenizer.from_pretrained(model_path) tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_path, load_in_4bit=True, device_map={"": 0}, **from_pretrained_kwargs model_path, load_in_4bit=True, **from_pretrained_kwargs
) )
return model, tokenizer return model, tokenizer
@@ -127,7 +127,6 @@ class FalconAdapater(BaseLLMAdaper):
model_path, model_path,
load_in_4bit=True, # quantize load_in_4bit=True, # quantize
quantization_config=bnb_config, quantization_config=bnb_config,
device_map={"": 0},
trust_remote_code=True, trust_remote_code=True,
**from_pretrained_kwagrs, **from_pretrained_kwagrs,
) )
@@ -135,7 +134,6 @@ class FalconAdapater(BaseLLMAdaper):
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_path, model_path,
trust_remote_code=True, trust_remote_code=True,
device_map={"": 0},
**from_pretrained_kwagrs, **from_pretrained_kwagrs,
) )
return model, tokenizer return model, tokenizer

View File

@@ -73,12 +73,12 @@ class ModelLoader(metaclass=Singleton):
elif self.device == "cuda": elif self.device == "cuda":
kwargs = {"torch_dtype": torch.float16} kwargs = {"torch_dtype": torch.float16}
num_gpus = int(num_gpus) num_gpus = torch.cuda.device_count()
if num_gpus != 1: if num_gpus != 1:
kwargs["device_map"] = "auto" kwargs["device_map"] = "auto"
if max_gpu_memory is None: # if max_gpu_memory is None:
kwargs["device_map"] = "sequential" # kwargs["device_map"] = "sequential"
available_gpu_memory = get_gpu_memory(num_gpus) available_gpu_memory = get_gpu_memory(num_gpus)
kwargs["max_memory"] = { kwargs["max_memory"] = {