mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 14:33:20 +00:00
fix diff device in some partition
This commit is contained in:
parent
3a15b20421
commit
9ae9e74017
@ -789,6 +789,8 @@ class WorkerBase(ABC):
|
||||
|
||||
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device,
|
||||
process_types=torch.device) # change devices from last stage to current device
|
||||
|
||||
args, kwargs = data_process_func(args_kwargs)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user