mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[shardformer] upgrade transformers to 4.39.3 (#5815)
* [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807) * [shardformer] fix modeling of gpt2 and gptj * [shardformer] fix whisper modeling * [misc] update requirements --------- Co-authored-by: ver217 <lhx0217@gmail.com> * [shardformer]upgrade transformers for mistral (#5808) * upgrade transformers for mistral * fix * fix * [shardformer]upgrade transformers for llama (#5809) * update transformers fix * fix * fix * [inference] upgrade transformers (#5810) * update transformers fix * fix * fix * fix * fix * [gemini] update transformers for gemini (#5814) --------- Co-authored-by: ver217 <lhx0217@gmail.com>
This commit is contained in:
@@ -314,7 +314,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
|
||||
use_safetensors (bool): whether to use safetensors to save the checkpoint.
|
||||
"""
|
||||
# Move all tensors in the state_dict to CPU before saving to avoid serialization issues
|
||||
state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
|
||||
state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict)
|
||||
|
||||
if use_safetensors:
|
||||
assert is_safetensors_available(), "safetensors is not available."
|
||||
|
Reference in New Issue
Block a user