mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
Fix ColossalEval (#5186)
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
This commit is contained in:
@@ -116,10 +116,10 @@ class HuggingFaceModel(BaseModel):
|
||||
shard_config: Shard config for tensor parallel.
|
||||
|
||||
"""
|
||||
model_kwargs.setdefault("torch_dtype", torch.float16)
|
||||
|
||||
if "torch_dtype" in model_kwargs:
|
||||
model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
|
||||
else:
|
||||
model_kwargs.setdefault("torch_dtype", torch.float16)
|
||||
|
||||
if "config" in model_kwargs:
|
||||
model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])
|
||||
@@ -586,11 +586,10 @@ class HuggingFaceCausalLM(HuggingFaceModel):
|
||||
shard_config: Shard config for tensor parallel.
|
||||
|
||||
"""
|
||||
|
||||
model_kwargs.setdefault("torch_dtype", torch.float16)
|
||||
|
||||
if "torch_dtype" in model_kwargs:
|
||||
model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
|
||||
else:
|
||||
model_kwargs.setdefault("torch_dtype", torch.float16)
|
||||
|
||||
if "config" in model_kwargs:
|
||||
model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])
|
||||
|
Reference in New Issue
Block a user