[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)
* [fx] move meta registration
* [fx] fix tests.
* [fx] fix test.
* [fx] fix.
* [meta] refactor meta registration.py.
* [fx] add compatibility descriptions.
* [fx] polish import.
* [fx] add a decorator.
* [fx] fix tests.
* [fx] remove print.
* [fx] edit raise error.
* [fx] edit raise error.
* [fx] add type hint.
* [fx] fix import in experimental.
* [rpc] remove color debug.
* [meta] fix naming.
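The fx-side changes described in the message (meta registrations moved into the fx folder and registered behind version compatibility checks via a decorator) are not shown in the hunks below, which cover only the rpc import cleanup and the color-debug removal. As a rough sketch of that compatibility-check idea, under the assumption that registration is skipped when the installed torch is too old (the names register_meta, _meta_registry, meta_relu and the version threshold are illustrative, not ColossalAI's actual API):

# Illustrative sketch only: hypothetical names, not ColossalAI's actual API.
from typing import Callable, Dict

import torch

_meta_registry: Dict[str, Callable] = {}


def _torch_major_minor() -> tuple:
    # "1.12.1+cu113" -> (1, 12); any local/suffix part is ignored.
    major, minor = torch.__version__.split('.')[:2]
    return int(major), int(''.join(c for c in minor if c.isdigit()) or 0)


def register_meta(op_name: str, min_torch: tuple = (1, 12)) -> Callable:
    """Register a meta implementation for op_name only when the installed
    torch version is new enough; otherwise leave it unregistered."""

    def wrapper(fn: Callable) -> Callable:
        if _torch_major_minor() >= min_torch:
            _meta_registry[op_name] = fn
        return fn

    return wrapper


@register_meta('aten::relu', min_torch=(1, 12))
def meta_relu(x: torch.Tensor) -> torch.Tensor:
    # Meta tensors carry only shape and dtype, no real storage.
    return torch.empty_like(x, device='meta')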
@@ -1,23 +1,19 @@
-import threading
-from enum import Enum
-from typing import List, Any, Tuple, Dict, Callable
-from functools import partial
-from abc import ABC, abstractmethod
-import math
-import inspect
+import math
+import threading
+from abc import ABC, abstractmethod
+from enum import Enum
+from functools import partial
+from typing import Any, Callable, Dict, List, Tuple
 
 import torch
-from torch import nn
 import torch.distributed.rpc as rpc
-from torch.futures import Future
-from torch._C._distributed_rpc import PyRRef
-
-from torch import autograd
-from torch import optim
+from torch import autograd, nn, optim
+from torch._C._distributed_rpc import PyRRef
+from torch.futures import Future
 
 from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
-                                           pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)
+from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
+                                           split_batch, tensor_shape_list, type_detail)
 
 
 class Phase(Enum):
@@ -195,7 +191,6 @@ class WorkerBase(ABC):
         if isinstance(output, Future):
             output = output.wait()
 
-        # color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
         output_work_item.refcount += 1
 
         # all consumers have been satisfied, the work_item can be released
@@ -250,9 +245,6 @@ class WorkerBase(ABC):
                              self.num_microbatches, forward_only)
         with self.work_list_condition_lock:
             self.work_list[key] = work_item
-            if use_color_debug:
-                color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}',
-                            'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()
 
     # just for last pp_rank
@@ -273,9 +265,6 @@ class WorkerBase(ABC):
             work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
                                  self.num_microbatches, False)
 
-            if use_color_debug:
-                color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
-
             self.work_list[key] = work_item
             self.work_list_condition_lock.notify_all()
 
@@ -297,23 +286,14 @@ class WorkerBase(ABC):
             producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
             subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
 
-        if use_color_debug:
-            color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
-                        'data dispatch', 'magenta')
-
         work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
                                            microbatch_id, None, self.num_microbatches, forward_only)
 
-        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
         # add work_item to work_list
         with self.work_list_condition_lock:
             key = UniqueKey(microbatch_id, Phase.FORWARD)
             assert key not in self.work_list
             self.work_list[key] = work_item_from_producer
-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
-                    'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()
 
     def subscribe_consumer(self, microbatch_id: int):
@@ -328,10 +308,6 @@ class WorkerBase(ABC):
         subscribe_backward_futures: List[Future] = [None] * consumer_num
         output = self._get_future_by_device()
 
-        if use_color_debug:
-            color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
-                        'data dispatch', 'magenta')
-
         for i in range(consumer_num):
             consumer_stage_id = self.consumer_stage_ids[i]
             consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
@@ -342,17 +318,11 @@ class WorkerBase(ABC):
         work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
                                            microbatch_id, None, self.num_microbatches, False)
 
-        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
-
         # add work_item to work_list
        with self.work_list_condition_lock:
             key = UniqueKey(microbatch_id, Phase.BACKWARD)
             assert key not in self.work_list
             self.work_list[key] = work_item_from_consumer
-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
-                    'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()
 
     def _get_producer_consumer(self) -> None:
@@ -406,11 +376,6 @@ class WorkerBase(ABC):
         is_first_stage = self.is_first_stage()
         is_last_stage = self.is_last_stage()
 
-        # if self.pp_rank == 0:
-        #     print(
-        #         f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
-        #     )
-
         if phase == Phase.FORWARD:
             # remind its consumer to get data before forward
             if not is_last_stage:
@@ -470,8 +435,6 @@ class WorkerBase(ABC):
             else:
                 consume_result = self.module_partition(*args, **kwargs)
 
-            # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
-
             if is_last_stage and self.criterion:
                 with self.label_lock:
                     self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
@@ -539,10 +502,6 @@ class WorkerBase(ABC):
             pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
             pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
 
-            # for input_node in stage_input_args:
-            #     if isinstance(input_node, torch.Tensor):
-            #         consume_result.append(input_node.grad)
-
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
 
@@ -593,11 +552,6 @@ class WorkerBase(ABC):
             with self.work_list_condition_lock:
                 work_item = self.work_list.pop(work_item_key)
 
-            if use_color_debug:
-                color_debug(
-                    f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
-                    'work loop', 'green')
-
             with self.output_list_condition_lock:
                 # assert work_item_key not in self.output_list
                 self.output_list[work_item_key] = work_item
@@ -605,11 +559,6 @@ class WorkerBase(ABC):
 
             consume_result = self._consume_work_item_by_phase(work_item)
 
-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
-                    'work loop', 'green')
-
             work_item.output.set_result(consume_result)
 
             # if is last step in one batch reset context and do step
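The hunks above delete the use_color_debug guard and its color_debug(...) calls outright. If similar per-rank tracing were ever wanted again without the colorama dependency, routing it through the standard logging module is one option; the sketch below is only an illustration and is not part of this commit (pp_rank is assumed to be the worker attribute used above):

import logging

logger = logging.getLogger("colossalai.pipeline.rpc")


def debug_rank(pp_rank: int, message: str, prefix: str = "data dispatch") -> None:
    # Formats and emits only when DEBUG is enabled, keeping the hot path cheap.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("[%s] rank %d: %s", prefix, pp_rank, message)


# Roughly equivalent to one of the removed call sites:
# debug_rank(self.pp_rank, f"receive data from dataloader {self._get_store_len()}")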
@@ -1,13 +1,12 @@
-from typing import List, Callable, Dict
 import threading
+from typing import Callable, Dict, List
 
 import torch
 import torch.distributed as dist
-from torch.futures import Future
-from torch._C._distributed_rpc import PyRRef
+from torch._C._distributed_rpc import PyRRef
+from torch.futures import Future
 
-from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
 from colossalai.pipeline.pipeline_process_group import ppg
-
+from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
 # Implementation of different Pipeline schedule
 # <strategy>Worker defines the worker for each stage
@@ -1,25 +1,15 @@
-from typing import List, Any, Tuple, Dict, Callable, Type, Union
 import argparse
 import os
 import warnings
+from typing import Any, Callable, Dict, List, Tuple, Type, Union
 
 import torch
-import torch.multiprocessing as mp
-from torch.futures import Future
 import torch.distributed.rpc as rpc
-from torch._C._distributed_rpc import _is_current_rpc_agent_set
-from colorama import Back, Style
+import torch.multiprocessing as mp
+from torch._C._distributed_rpc import _is_current_rpc_agent_set
+from torch.futures import Future
 
 from colossalai.initialize import launch
 from colossalai.pipeline.pipeline_process_group import ppg
 
-# config for debug and test
-use_color_debug = False
-
-
-def color_debug(text, prefix=' ', color='blue'):
-    color = color.upper()
-    print(getattr(Back, color), prefix, Style.RESET_ALL, text)
-
-
 
 def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
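The pytree_map helper whose signature closes this hunk is what _pipeline_base uses above to collect gradients, for example pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor). A minimal sketch of a helper with that signature, walking nested lists, tuples and dicts and applying fn either to every leaf (map_all=True) or only to leaves matching process_types, is given below as an illustration; it is not the repository's exact implementation:

from typing import Any, Callable, Tuple, Type, Union


def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
    # Recurse into common containers while preserving their structure.
    if isinstance(obj, dict):
        return {k: pytree_map(v, fn, process_types, map_all) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(pytree_map(v, fn, process_types, map_all) for v in obj)
    # Apply fn to every leaf, or only to leaves of the requested types.
    if map_all or (process_types and isinstance(obj, process_types)):
        return fn(obj)
    return obj

Called with a side-effecting fn as in the backward hunk, such a traversal simply visits every torch.Tensor leaf and appends its .grad to consume_result.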