mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-17 08:51:59 +00:00
[shardformer] add util functions for shardformer tests/fix sync_shared_param (#4366)
* add util functions for shardformer tests & rewrite gpt2 test * fix shared_params & embedding/merging * fix precision
This commit is contained in:
parent
5c6f183192
commit
b1feeced8e
@ -37,7 +37,8 @@ class HybridParallelModule(ModelWrapper):
|
|||||||
self.shared_param_process_groups = []
|
self.shared_param_process_groups = []
|
||||||
for shared_param in self.shared_params:
|
for shared_param in self.shared_params:
|
||||||
if len(shared_param) > 0:
|
if len(shared_param) > 0:
|
||||||
self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
|
self.shared_param_process_groups.append(
|
||||||
|
self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
|
||||||
if precision == 'fp16':
|
if precision == 'fp16':
|
||||||
module = module.half().cuda()
|
module = module.half().cuda()
|
||||||
elif precision == 'bf16':
|
elif precision == 'bf16':
|
||||||
|
@ -72,7 +72,9 @@ config = transformers.GPT2Config(n_layer=2,
|
|||||||
embd_pdrop=0,
|
embd_pdrop=0,
|
||||||
resid_pdrop=0,
|
resid_pdrop=0,
|
||||||
summary_first_dropout=0,
|
summary_first_dropout=0,
|
||||||
hidden_dropout=0)
|
hidden_dropout=0,
|
||||||
|
problem_type="single_label_classification",
|
||||||
|
pad_token_id=50256)
|
||||||
|
|
||||||
config_for_token_classification = copy.deepcopy(config)
|
config_for_token_classification = copy.deepcopy(config)
|
||||||
config_for_token_classification.num_labels = 2
|
config_for_token_classification.num_labels = 2
|
||||||
|
@ -1,11 +1,19 @@
|
|||||||
import copy
|
import copy
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from torch import Tensor
|
||||||
|
from torch import distributed as dist
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
from torch.optim import Adam, Optimizer
|
||||||
|
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import HybridParallelPlugin
|
||||||
from colossalai.lazy import LazyInitContext
|
from colossalai.lazy import LazyInitContext
|
||||||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||||
from colossalai.shardformer._utils import getattr_
|
from colossalai.shardformer._utils import getattr_
|
||||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||||
@ -79,20 +87,151 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
|
|||||||
assert torch.equal(v, shard_v), f'{name} {k} value mismatch'
|
assert torch.equal(v, shard_v), f'{name} {k} value mismatch'
|
||||||
|
|
||||||
|
|
||||||
def check_grad(original_model, sharded_model, layer_suffix, atol=1e-5, rtol=1e-5, dim=0, verbose=False):
|
def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]):
|
||||||
|
|
||||||
|
use_lazy_init = False
|
||||||
|
if 'use_lazy_init' in test_config:
|
||||||
|
use_lazy_init = test_config.pop('use_lazy_init')
|
||||||
|
|
||||||
|
if use_lazy_init:
|
||||||
|
ctx = LazyInitContext()
|
||||||
|
else:
|
||||||
|
ctx = nullcontext()
|
||||||
|
|
||||||
|
plugin = HybridParallelPlugin(**test_config)
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
with ctx:
|
||||||
|
org_model = model_fn().cuda()
|
||||||
|
sharded_model = copy.deepcopy(org_model)
|
||||||
|
|
||||||
|
if use_lazy_init:
|
||||||
|
org_model = ctx.materialize(org_model)
|
||||||
|
|
||||||
|
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
|
||||||
|
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
|
||||||
|
criterion = loss_fn
|
||||||
|
|
||||||
|
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
|
||||||
|
|
||||||
|
return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
|
||||||
|
|
||||||
|
|
||||||
|
def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
|
||||||
|
data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
|
||||||
|
booster: Booster):
|
||||||
|
|
||||||
|
def _criterion(outputs, inputs):
|
||||||
|
outputs = output_transform_fn(outputs)
|
||||||
|
loss = criterion(outputs)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
data = data_gen_fn()
|
||||||
|
sharded_model.train()
|
||||||
|
if booster.plugin.stage_manager is not None:
|
||||||
|
data = {
|
||||||
|
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
|
||||||
|
for k, v in data.items()
|
||||||
|
}
|
||||||
|
data_iter = iter([data])
|
||||||
|
sharded_output = booster.execute_pipeline(data_iter,
|
||||||
|
sharded_model,
|
||||||
|
_criterion,
|
||||||
|
sharded_optimizer,
|
||||||
|
return_loss=True,
|
||||||
|
return_outputs=True)
|
||||||
|
sharded_loss = sharded_output['loss']
|
||||||
|
else:
|
||||||
|
data = {k: v.cuda() for k, v in data.items()}
|
||||||
|
sharded_output = sharded_model(**data)
|
||||||
|
sharded_loss = criterion(sharded_output)
|
||||||
|
sharded_loss.backward()
|
||||||
|
|
||||||
|
org_model.train()
|
||||||
|
org_output = org_model(**data)
|
||||||
|
org_loss = criterion(org_output)
|
||||||
|
org_loss.backward()
|
||||||
|
|
||||||
|
return org_loss, org_output, sharded_loss, sharded_output
|
||||||
|
|
||||||
|
|
||||||
|
def check_output_hidden_state(org_output: Tensor,
|
||||||
|
sharded_output: Tensor,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
atol: float = 1e-5,
|
||||||
|
rtol: float = 1e-3):
|
||||||
|
|
||||||
|
org_hidden_state = org_output.last_hidden_state
|
||||||
|
|
||||||
|
if stage_manager is None:
|
||||||
|
sharded_hidden_state = sharded_output.last_hidden_state
|
||||||
|
|
||||||
|
if stage_manager and stage_manager.is_last_stage():
|
||||||
|
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
|
||||||
|
|
||||||
|
assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \
|
||||||
|
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
|
||||||
|
|
||||||
|
|
||||||
|
def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
|
||||||
|
assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
|
||||||
|
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
|
||||||
|
|
||||||
|
|
||||||
|
def check_weight(org_model: Module,
|
||||||
|
sharded_model: Module,
|
||||||
|
layer_suffix: List[str],
|
||||||
|
tp_group: Optional[ProcessGroup] = None,
|
||||||
|
dim: int = 0,
|
||||||
|
atol: float = 1e-5,
|
||||||
|
rtol: float = 1e-3,
|
||||||
|
verbose: bool = False):
|
||||||
|
|
||||||
for suffix in layer_suffix:
|
for suffix in layer_suffix:
|
||||||
org_grad = getattr_(original_model, suffix).weight.grad
|
org_weight = getattr_(org_model, suffix).weight
|
||||||
|
sharded_weight = getattr_(sharded_model, suffix).weight
|
||||||
|
|
||||||
|
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
|
||||||
|
sharded_weight_list = [
|
||||||
|
torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
|
||||||
|
]
|
||||||
|
dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
|
||||||
|
sharded_weight = torch.cat(sharded_weight_list, dim=dim)
|
||||||
|
|
||||||
|
if verbose and dist.get_rank() == 0:
|
||||||
|
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
|
||||||
|
|
||||||
|
assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \
|
||||||
|
f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
|
||||||
|
|
||||||
|
|
||||||
|
def check_grad(org_model: Module,
|
||||||
|
sharded_model: Module,
|
||||||
|
layer_suffix: List[str],
|
||||||
|
tp_group: ProcessGroup = None,
|
||||||
|
dim: int = 0,
|
||||||
|
atol: float = 1e-5,
|
||||||
|
rtol: float = 1e-3,
|
||||||
|
verbose: bool = False):
|
||||||
|
|
||||||
|
for suffix in layer_suffix:
|
||||||
|
org_grad = getattr_(org_model, suffix).weight.grad
|
||||||
shard_grad = getattr_(sharded_model, suffix).weight.grad
|
shard_grad = getattr_(sharded_model, suffix).weight.grad
|
||||||
shard_weight = getattr_(sharded_model, suffix).weight
|
shard_weight = getattr_(sharded_model, suffix).weight
|
||||||
|
|
||||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size())]
|
shard_grad_list = [
|
||||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
|
||||||
all_shard_grad = torch.cat(shard_grad_list, dim=dim)
|
]
|
||||||
else:
|
dist.all_gather(shard_grad_list, shard_grad, tp_group)
|
||||||
all_shard_grad = shard_grad
|
shard_grad = torch.cat(shard_grad_list, dim=dim)
|
||||||
|
|
||||||
|
# embedding may be resized when using tensor parallel
|
||||||
|
if shard_grad.shape[0] > org_grad.shape[0]:
|
||||||
|
shard_grad = shard_grad[:org_grad.shape[0], :]
|
||||||
|
|
||||||
if verbose and dist.get_rank() == 0:
|
if verbose and dist.get_rank() == 0:
|
||||||
print(f"'{suffix}' grad: {org_grad}, {all_shard_grad}")
|
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
org_grad, all_shard_grad, rtol=rtol, atol=atol
|
org_grad, shard_grad, rtol=rtol, atol=atol
|
||||||
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{all_shard_grad}"
|
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
|
||||||
|
@ -1,107 +1,48 @@
|
|||||||
import copy
|
|
||||||
from contextlib import nullcontext
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import distributed as dist
|
from torch import distributed as dist
|
||||||
from torch.optim import Adam
|
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.booster import Booster
|
|
||||||
from colossalai.booster.plugin import HybridParallelPlugin
|
|
||||||
from colossalai.lazy.lazy_init import LazyInitContext
|
|
||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
from colossalai.tensor.d_tensor.api import (
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||||
clear_layout_converter,
|
|
||||||
is_customized_distributed_tensor,
|
|
||||||
is_distributed_tensor,
|
|
||||||
)
|
|
||||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
|
from tests.test_shardformer.test_model._utils import (
|
||||||
|
build_model_from_hybrid_plugin,
|
||||||
|
check_grad,
|
||||||
|
check_loss,
|
||||||
|
check_output_hidden_state,
|
||||||
|
check_weight,
|
||||||
|
run_forward_backward_with_hybrid_plugin,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||||
|
|
||||||
use_lazy_init = False
|
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
|
||||||
if 'use_lazy_init' in test_config:
|
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
|
||||||
use_lazy_init = test_config.pop('use_lazy_init')
|
|
||||||
|
|
||||||
if use_lazy_init:
|
|
||||||
ctx = LazyInitContext()
|
|
||||||
else:
|
|
||||||
ctx = nullcontext()
|
|
||||||
|
|
||||||
# prepare booster
|
org_loss, org_output, sharded_loss, sharded_output = \
|
||||||
plugin = HybridParallelPlugin(**test_config)
|
run_forward_backward_with_hybrid_plugin(
|
||||||
booster = Booster(plugin=plugin)
|
org_model,
|
||||||
stage_manager = plugin.stage_manager
|
sharded_model,
|
||||||
|
sharded_optimizer,
|
||||||
|
data_gen_fn,
|
||||||
|
output_transform_fn,
|
||||||
|
criterion,
|
||||||
|
booster)
|
||||||
|
|
||||||
# prepare models and optimizers
|
stage_manager = booster.plugin.stage_manager
|
||||||
with ctx:
|
tp_group = booster.plugin.tp_group
|
||||||
org_model = model_fn().cuda()
|
|
||||||
sharded_model = copy.deepcopy(org_model)
|
|
||||||
|
|
||||||
if use_lazy_init:
|
|
||||||
org_model = ctx.materialize(org_model)
|
|
||||||
|
|
||||||
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
|
|
||||||
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
|
|
||||||
criterion = loss_fn
|
|
||||||
|
|
||||||
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
|
|
||||||
|
|
||||||
def _criterion(outputs, inputs):
|
|
||||||
outputs = output_transform_fn(outputs)
|
|
||||||
loss = criterion(outputs)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
# do forward and backward
|
|
||||||
data = data_gen_fn()
|
|
||||||
sharded_model.train()
|
|
||||||
if stage_manager:
|
|
||||||
data = {
|
|
||||||
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
|
|
||||||
for k, v in data.items()
|
|
||||||
}
|
|
||||||
data_iter = iter([data])
|
|
||||||
sharded_output = booster.execute_pipeline(data_iter,
|
|
||||||
sharded_model,
|
|
||||||
_criterion,
|
|
||||||
sharded_optimizer,
|
|
||||||
return_loss=True,
|
|
||||||
return_outputs=True)
|
|
||||||
sharded_loss = sharded_output['loss']
|
|
||||||
else:
|
|
||||||
data = {k: v.cuda() for k, v in data.items()}
|
|
||||||
sharded_output = sharded_model(**data)
|
|
||||||
sharded_loss = criterion(sharded_output)
|
|
||||||
sharded_loss.backward()
|
|
||||||
|
|
||||||
org_model.train()
|
|
||||||
org_output = org_model(**data)
|
|
||||||
org_loss = criterion(org_output)
|
|
||||||
org_loss.backward()
|
|
||||||
|
|
||||||
|
# check last hidden state & loss
|
||||||
if stage_manager is None or stage_manager.is_last_stage():
|
if stage_manager is None or stage_manager.is_last_stage():
|
||||||
|
|
||||||
# check last hidden state
|
|
||||||
if org_model.__class__.__name__ == 'GPT2Model':
|
if org_model.__class__.__name__ == 'GPT2Model':
|
||||||
org_hidden_state = org_output.last_hidden_state
|
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
|
||||||
|
|
||||||
if stage_manager is None:
|
check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
|
||||||
sharded_hidden_state = sharded_output.last_hidden_state
|
|
||||||
|
|
||||||
if stage_manager and stage_manager.is_last_stage():
|
|
||||||
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']],
|
|
||||||
dim=0)
|
|
||||||
|
|
||||||
assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \
|
|
||||||
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
|
|
||||||
|
|
||||||
# check loss
|
|
||||||
assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \
|
|
||||||
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
|
|
||||||
|
|
||||||
# unwrap model
|
# unwrap model
|
||||||
if org_model.__class__.__name__ == 'GPT2Model':
|
if org_model.__class__.__name__ == 'GPT2Model':
|
||||||
@ -111,27 +52,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||||||
gpt2 = org_model.transformer
|
gpt2 = org_model.transformer
|
||||||
sharded_gpt2 = sharded_model.unwrap().transformer
|
sharded_gpt2 = sharded_model.unwrap().transformer
|
||||||
|
|
||||||
# check grad
|
|
||||||
col_layer_for_check = ['h[0].mlp.c_fc']
|
col_layer_for_check = ['h[0].mlp.c_fc']
|
||||||
row_layer_for_check = ['h[0].mlp.c_proj']
|
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
|
||||||
check_grad(gpt2, sharded_gpt2, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
|
|
||||||
check_grad(gpt2, sharded_gpt2, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
|
# check grad
|
||||||
|
if stage_manager is None or stage_manager.is_first_stage():
|
||||||
|
check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
|
||||||
|
check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
|
||||||
|
|
||||||
# check weights after optimizer.step()
|
# check weights after optimizer.step()
|
||||||
org_optimizer.step()
|
org_optimizer.step()
|
||||||
sharded_optimizer.step()
|
sharded_optimizer.step()
|
||||||
if stage_manager is None or stage_manager.is_first_stage():
|
if stage_manager is None or stage_manager.is_first_stage():
|
||||||
|
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
|
||||||
org_weight = org_model.h[0].mlp.c_fc.weight
|
|
||||||
shard_weight = sharded_model.h[0].mlp.c_fc.weight
|
|
||||||
|
|
||||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
|
||||||
shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)]
|
|
||||||
dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group)
|
|
||||||
shard_weight = torch.cat(shard_weight_list, dim=1)
|
|
||||||
|
|
||||||
assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \
|
|
||||||
f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}"
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@ -156,9 +89,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
def run_gpt2_test(test_config):
|
def run_gpt2_test(test_config):
|
||||||
|
|
||||||
# TODO: add plugin_config for TP+DP after supporting & debugging it
|
# TODO: add test_config for TP+DP after supporting & debugging it
|
||||||
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
||||||
|
|
||||||
|
# TODO: add test_config for flash attention & jit operator after supporting
|
||||||
|
|
||||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||||
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
||||||
|
|
||||||
@ -175,7 +110,6 @@ def check_gpt2(rank, world_size, port):
|
|||||||
run_gpt2_test()
|
run_gpt2_test()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip('Have some bug caused by merge')
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
|
Loading…
Reference in New Issue
Block a user