[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)

* [fx] move meta registration

* [fx] fix tests.

* [fx] fix test.

* [fx] fix.

* [meta] refactor meta registration.py.

* [fx] add compatibility descriptions.

* [fx] polish import.

* [fx] add a decorator.

* [fx] fix tests.

* [fx] remove print.

* [fx] edit raise error.

* [fx] edit raise error.

* [fx] add type hint.

* [fx] fix import in experimental.

* [rpc] remove color debug.

* [meta] fix naming.
Author: Super Daniel
Date: 2022-10-18 10:44:23 +08:00
Committed by: GitHub
Parent: e8d8eda5e7
Commit: 393f594051
32 changed files with 351 additions and 310 deletions
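Note for readers who only see the import and debug-print cleanup below: the change advertised in the title is that fx meta functions are registered from a module inside the fx folder and gated by compatibility checks. The sketch below illustrates that registration pattern under stated assumptions; the names (meta_table, compatibility, register_meta, the minimum torch version) are illustrative and are not the exact ColossalAI API.

from typing import Callable, Dict

import torch
from packaging import version

# Hypothetical registry mapping an aten overload to its meta implementation.
meta_table: Dict[Callable, Callable] = {}


def compatibility(is_backward_compatible: bool = False) -> Callable:
    """Tag a function with a backward-compatibility flag (assumed decorator)."""

    def decorator(func: Callable) -> Callable:
        func._is_backward_compatible = is_backward_compatible
        return func

    return decorator


def register_meta(aten_op: Callable, min_torch: str = "1.12.0") -> Callable:
    """Register `func` for `aten_op` only if the installed torch is new enough."""

    def wrapper(func: Callable) -> Callable:
        if version.parse(torch.__version__) >= version.parse(min_torch):
            meta_table[aten_op] = func
        return func

    return wrapper


@compatibility(is_backward_compatible=False)
@register_meta(torch.ops.aten.relu.default)
def meta_relu(input: torch.Tensor) -> torch.Tensor:
    # Meta functions only propagate shape/dtype; no real computation happens.
    return torch.empty_like(input)

If the real decorator differs, only the names change; the point is that registration is gated by the installed torch version instead of failing at call time.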

View File

@@ -1,23 +1,19 @@
-import threading
-from enum import Enum
-from typing import List, Any, Tuple, Dict, Callable
-from functools import partial
-from abc import ABC, abstractmethod
-import math
 import inspect
+import math
+import threading
+from abc import ABC, abstractmethod
+from enum import Enum
+from functools import partial
+from typing import Any, Callable, Dict, List, Tuple
 import torch
-from torch import nn
 import torch.distributed.rpc as rpc
-from torch.futures import Future
-from torch._C._distributed_rpc import PyRRef
-from torch import autograd
-from torch import optim
 from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
-                                           pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)
+from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
+                                           split_batch, tensor_shape_list, type_detail)
+from torch import autograd, nn, optim
+from torch._C._distributed_rpc import PyRRef
+from torch.futures import Future
 class Phase(Enum):
@@ -195,7 +191,6 @@ class WorkerBase(ABC):
         if isinstance(output, Future):
             output = output.wait()
-        # color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
         output_work_item.refcount += 1
         # all consumers have been satisfied, the work_item can be released
@@ -250,9 +245,6 @@ class WorkerBase(ABC):
                              self.num_microbatches, forward_only)
         with self.work_list_condition_lock:
             self.work_list[key] = work_item
-            if use_color_debug:
-                color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}',
-                            'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()
     # just for last pp_rank
@@ -273,9 +265,6 @@ class WorkerBase(ABC):
             work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
                                  self.num_microbatches, False)
-            if use_color_debug:
-                color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
             self.work_list[key] = work_item
             self.work_list_condition_lock.notify_all()
@@ -297,23 +286,14 @@ class WorkerBase(ABC):
             producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
             subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
-        if use_color_debug:
-            color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
-                        'data dispatch', 'magenta')
         work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
                                            microbatch_id, None, self.num_microbatches, forward_only)
-        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
         # add work_item to work_list
         with self.work_list_condition_lock:
             key = UniqueKey(microbatch_id, Phase.FORWARD)
             assert key not in self.work_list
             self.work_list[key] = work_item_from_producer
-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
-                    'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()
     def subscribe_consumer(self, microbatch_id: int):
def subscribe_consumer(self, microbatch_id: int):
@@ -328,10 +308,6 @@ class WorkerBase(ABC):
         subscribe_backward_futures: List[Future] = [None] * consumer_num
         output = self._get_future_by_device()
-        if use_color_debug:
-            color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
-                        'data dispatch', 'magenta')
         for i in range(consumer_num):
             consumer_stage_id = self.consumer_stage_ids[i]
             consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
@@ -342,17 +318,11 @@ class WorkerBase(ABC):
         work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
                                            microbatch_id, None, self.num_microbatches, False)
-        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
         # add work_item to work_list
         with self.work_list_condition_lock:
             key = UniqueKey(microbatch_id, Phase.BACKWARD)
             assert key not in self.work_list
             self.work_list[key] = work_item_from_consumer
-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
-                    'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()
     def _get_producer_consumer(self) -> None:
@@ -406,11 +376,6 @@ class WorkerBase(ABC):
         is_first_stage = self.is_first_stage()
         is_last_stage = self.is_last_stage()
-        # if self.pp_rank == 0:
-        #     print(
-        #         f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
-        #     )
         if phase == Phase.FORWARD:
             # remind its consumer to get data before forward
             if not is_last_stage:
@@ -470,8 +435,6 @@ class WorkerBase(ABC):
             else:
                 consume_result = self.module_partition(*args, **kwargs)
-            # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
             if is_last_stage and self.criterion:
                 with self.label_lock:
                     self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
@@ -539,10 +502,6 @@ class WorkerBase(ABC):
             pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
             pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
-            # for input_node in stage_input_args:
-            #     if isinstance(input_node, torch.Tensor):
-            #         consume_result.append(input_node.grad)
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@@ -593,11 +552,6 @@ class WorkerBase(ABC):
             with self.work_list_condition_lock:
                 work_item = self.work_list.pop(work_item_key)
-            if use_color_debug:
-                color_debug(
-                    f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
-                    'work loop', 'green')
             with self.output_list_condition_lock:
                 # assert work_item_key not in self.output_list
                 self.output_list[work_item_key] = work_item
@@ -605,11 +559,6 @@ class WorkerBase(ABC):
             consume_result = self._consume_work_item_by_phase(work_item)
-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
-                    'work loop', 'green')
             work_item.output.set_result(consume_result)
             # if is last step in one batch reset context and do step

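The hunks above delete the color_debug prints (and the use_color_debug flag checks) outright rather than keeping them behind a switch. If that tracing is still wanted locally, one standard-library alternative is to route the same messages through logging; this is a suggestion, not part of the commit:

import logging

# Per-module logger; enable it with logging.basicConfig(level=logging.DEBUG).
logger = logging.getLogger("colossalai.pipeline.rpc")


def trace(pp_rank: int, stage: str, message: str) -> None:
    # Same information as the removed color_debug(message, stage, '<color>')
    # calls, but silent unless DEBUG logging is switched on.
    logger.debug("rank %d | %s | %s", pp_rank, stage, message)


# Example, mirroring a deleted call:
#   color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
# becomes:
#   trace(self.pp_rank, 'data dispatch', 'propose backward')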
View File

@@ -1,13 +1,12 @@
-from typing import List, Callable, Dict
 import threading
+from typing import Callable, Dict, List
 import torch
 import torch.distributed as dist
-from torch.futures import Future
-from torch._C._distributed_rpc import PyRRef
-from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
 from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
+from torch._C._distributed_rpc import PyRRef
+from torch.futures import Future
 # Implementation of different Pipeline schedule
 # <strategy>Worker defines the worker for each stage

View File

@@ -1,25 +1,15 @@
-from typing import List, Any, Tuple, Dict, Callable, Type, Union
+import argparse
 import os
 import warnings
-import argparse
+from typing import Any, Callable, Dict, List, Tuple, Type, Union
 import torch
-import torch.multiprocessing as mp
-from torch.futures import Future
 import torch.distributed.rpc as rpc
-from torch._C._distributed_rpc import _is_current_rpc_agent_set
-from colorama import Back, Style
+import torch.multiprocessing as mp
 from colossalai.initialize import launch
 from colossalai.pipeline.pipeline_process_group import ppg
-# config for debug and test
-use_color_debug = False
-def color_debug(text, prefix=' ', color='blue'):
-    color = color.upper()
-    print(getattr(Back, color), prefix, Style.RESET_ALL, text)
+from torch._C._distributed_rpc import _is_current_rpc_agent_set
+from torch.futures import Future
 def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
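The last hunk keeps pytree_map, which several of the edited call sites above rely on (for example, collecting x.grad for every tensor in the stage inputs). Below is a minimal sketch of the behaviour implied by that signature, assuming the usual recursive walk over lists, tuples, and dicts; it is not the library's exact implementation.

from typing import Any, Callable, Tuple, Type, Union

import torch


def pytree_map_sketch(obj: Any,
                      fn: Callable,
                      process_types: Union[Type, Tuple[Type, ...]] = (),
                      map_all: bool = False) -> Any:
    # Recurse into common containers and apply fn to matching leaves.
    if isinstance(obj, dict):
        return {k: pytree_map_sketch(v, fn, process_types, map_all) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(pytree_map_sketch(v, fn, process_types, map_all) for v in obj)
    if map_all or isinstance(obj, process_types):
        return fn(obj)
    return obj


# Usage in the same spirit as the backward hunk above: gather the gradients of
# all tensor leaves while ignoring non-tensor entries.
x = torch.randn(2, 2, requires_grad=True)
x.sum().backward()
grads = []
pytree_map_sketch((x, "not a tensor"), lambda t: grads.append(t.grad), process_types=torch.Tensor)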