Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[plugin] torch ddp plugin supports sharded model checkpoint (#3775)
* [plugin] torch ddp plugin add save sharded model
* [test] fix torch ddp ckpt io test
* [test] fix torch ddp ckpt io test
* [test] fix low level zero plugin test
* [test] fix low level zero plugin test
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] add debug info
* [test] fix low level zero plugin test
* [test] fix low level zero plugin test
* [test] remove debug info
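For context, the sharded model checkpointing added by this PR is driven through the Booster API. The snippet below is a minimal sketch, not part of this commit's diff: it assumes Booster.save_model exposes shard and size_per_shard arguments and that launch_from_torch still takes a config dict, both of which may differ across ColossalAI versions.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# requires a torchrun launch so the distributed environment variables are set
colossalai.launch_from_torch(config={})

model = nn.Linear(16, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# wrap the model with torch DDP via the booster plugin
booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# with this PR the torch DDP plugin can save a sharded model checkpoint:
# a directory of weight shard files plus an index file mapping tensors to shards
booster.save_model(model, "ckpt_dir", shard=True, size_per_shard=1024)

# loading goes through the same entry point; the index file tells
# CheckpointIO.load_model that the checkpoint is sharded (see the hunk below)
booster.load_model(model, "ckpt_dir")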
@@ -1,7 +1,6 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Union
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -84,9 +83,8 @@ class CheckpointIO(ABC):
         # containing no distributed tensors, dtensor -> full tensor conversion
         # should be done offline via our CLI
         # the existence of index file means it is a sharded checkpoint
-        ckpt_path = Path(checkpoint)
         index_file_exists, index_file_path = has_index_file(checkpoint)
 
         # return the origin model instead of the unwrapped model
         origin_model = model
 
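The hunk above hinges on the index-file convention for sharded checkpoints: a sharded save produces a directory of weight shard files plus a JSON index that maps each weight name to the shard holding it, so the mere presence of an index file is enough to tell load_model which code path to take. The helper below is an illustrative stand-in for that check, not the actual has_index_file from colossalai.checkpoint_io.utils, whose behaviour may differ in detail.

from pathlib import Path
from typing import Optional, Tuple

def find_index_file(checkpoint: str) -> Tuple[bool, Optional[Path]]:
    """Return (True, index_path) if `checkpoint` looks like a sharded checkpoint."""
    ckpt_path = Path(checkpoint)
    if ckpt_path.is_file():
        # a single weight file is an unsharded checkpoint
        return False, None
    if ckpt_path.is_dir():
        # sharded checkpoints carry a *.index.json mapping weight names to shard files
        index_files = sorted(ckpt_path.glob("*.index.json"))
        if index_files:
            return True, index_files[0]
    return False, None

With such a check in hand, load_model can branch exactly as the diffed code does: dispatch to the sharded loading path when an index file is found, and the single-file path otherwise.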