Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-18 11:48:53 +00:00)
[pipeline] All bert models (#4233)
* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* Revert "bloom policy"
This reverts commit 8dee68a0a2.
This policy should be reverted and copied to feature/bloom
* revert the bloom changes
* remove unneeded inputs
* gpt
* finish llama
* causal lm and sequence classification
* revision
* add pure pipeline test
* finish some bert models
* finish all bert models
* finish bert tests
* fix bugs
* fix bugs
* fix test pipeline
* fix data gen for qa
* update set_pipeline_forward
* shared params
* fix bugs
This commit is contained in:
parent
a14d352088
commit
e7cc62d735
@@ -64,7 +64,10 @@ def _broadcast_object_list(object_list: List[Any],
     my_rank = dist.get_rank()
     # Serialize object_list elements to tensors on src rank.
     if my_rank == src:
-        tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
+        if torch.__version__ >= "1.13.0":
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list])
+        else:
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
         object_sizes_tensor = torch.cat(size_list)
     else:
         object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
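The guard above switches to the device-aware c10d._object_to_tensor call on newer torch. A plain string comparison of version numbers can misorder releases (recent torch releases expose torch.__version__ as a comparison-aware TorchVersion object, which mitigates this); a minimal sketch of the same check with explicit version parsing, assuming the packaging package is installed:

from packaging import version

import torch

# Lexicographic string comparison can misorder release numbers:
assert "1.9.0" >= "1.13.0"    # True as plain strings, even though 1.9 predates 1.13
assert not (version.parse("1.9.0") >= version.parse("1.13.0"))

# Version-gated behaviour, equivalent in spirit to the guard in the hunk above:
supports_device_arg = version.parse(torch.__version__) >= version.parse("1.13.0")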
@@ -205,7 +205,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             # the backward pass.
             input_obj = input_objs.pop(0)
             output_obj = output_objs.pop(0)
-
             input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)

             if last_iteration:
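For context, the 1F1B schedule keeps the inputs and outputs of every in-flight forward pass in FIFO queues and pops the oldest pair once its backward pass is due, which is what the input_objs.pop(0) / output_objs.pop(0) lines above do. A minimal standalone sketch of that bookkeeping, using deque and stand-in forward/backward callables rather than the real schedule API:

from collections import deque
from typing import Any, Callable, List

def one_forward_one_backward(microbatches: List[Any],
                             forward_step: Callable[[Any], Any],
                             backward_step: Callable[[Any, Any], None],
                             num_warmup: int) -> None:
    # Activations of microbatches that are still waiting for their backward pass.
    input_objs, output_objs = deque(), deque()
    for x in microbatches[:num_warmup]:              # warm-up: forward passes only
        input_objs.append(x)
        output_objs.append(forward_step(x))
    for x in microbatches[num_warmup:]:              # steady state: one forward, then one backward
        input_objs.append(x)
        output_objs.append(forward_step(x))
        backward_step(input_objs.popleft(), output_objs.popleft())
    while input_objs:                                # cool-down: drain the remaining backward passes
        backward_step(input_objs.popleft(), output_objs.popleft())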
@@ -42,6 +42,8 @@ _POLICY_LIST = {
         PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
     "transformers.models.bert.modeling_bert.BertForMultipleChoice":
         PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
+    "transformers.models.bert.modeling_bert.BertForQuestionAnswering":
+        PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"),

     # LLaMA
     "transformers.models.llama.modeling_llama.LlamaModel":
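The registry keys policies by the fully qualified transformers class name, so a policy module only has to be imported once a matching model is actually sharded. A minimal sketch of how such a registry can be resolved; the resolve_policy helper is hypothetical, and the package path is assumed from the imports that appear elsewhere in this diff:

import importlib
from dataclasses import dataclass

@dataclass
class PolicyLocation:
    file_name: str
    class_name: str

# Simplified registry in the spirit of _POLICY_LIST above.
_POLICY_LIST = {
    "transformers.models.bert.modeling_bert.BertForQuestionAnswering":
        PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"),
}

def resolve_policy(model):
    # Build the fully qualified class name of the model and look it up in the registry.
    key = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = _POLICY_LIST[key]
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()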
(File diff suppressed because it is too large)
@@ -212,11 +212,13 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama model"""
         llama_model = self.model.model
         if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight):
             # tie weights
-            return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: self.model.lm_head.weight}]
+            return [{
+                0: llama_model.embed_tokens.weight,
+                self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight
+            }]
         return []
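get_shared_params only reports the embedding and LM-head weights as a shared parameter (owned by stage 0 and the last stage) when they are actually tied, and comparing id() is how tying is detected, since tied modules hold the very same Parameter object. A minimal standalone sketch of that check, with toy dimensions:

import torch.nn as nn

embed_tokens = nn.Embedding(1000, 64)
lm_head = nn.Linear(64, 1000, bias=False)

# Tie the two modules to the same Parameter, as HF models do when tie_word_embeddings=True.
lm_head.weight = embed_tokens.weight

# id() compares object identity, so this is True only when both modules share one tensor.
tied = id(embed_tokens.weight) == id(lm_head.weight)
num_stages = 2
shared = [{0: embed_tokens.weight, num_stages - 1: lm_head.weight}] if tied else []
assert tied and len(shared) == 1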
@@ -1 +1 @@
-from .torchrec import *
+#from .torchrec import *
@@ -87,6 +87,17 @@ def data_gen_for_mcq():
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)


+def data_gen_for_qa():
+    # generating data for question answering
+    # no need for labels and use start and end position instead
+    data = data_gen()
+    start_positions = torch.tensor([0], dtype=torch.int64)
+    data['start_positions'] = start_positions
+    end_positions = torch.tensor([1], dtype=torch.int64)
+    data['end_positions'] = end_positions
+    return data
+
+
 # define output transform function
 output_transform_fn = lambda x: x
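For extractive question answering the model is supervised with start_positions and end_positions (token indices of the answer span) rather than per-token labels, which is why data_gen_for_qa drops labels and adds those two tensors. A small sketch of how transformers consumes them, assuming the transformers package and an arbitrarily reduced config:

import torch
import transformers

config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2,
                                 num_attention_heads=4, intermediate_size=256)
model = transformers.BertForQuestionAnswering(config)

input_ids = torch.tensor([[101, 2003, 2449, 102]])
attention_mask = torch.ones_like(input_ids)
# Answer-span token indices; the model builds a cross-entropy loss over start/end logits.
outputs = model(input_ids=input_ids,
                attention_mask=attention_mask,
                start_positions=torch.tensor([0], dtype=torch.int64),
                end_positions=torch.tensor([1], dtype=torch.int64))
print(outputs.loss, outputs.start_logits.shape, outputs.end_logits.shape)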
@@ -150,3 +161,9 @@ model_zoo.register(name='transformers_bert_for_mcq',
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_question_answering',
+                   model_fn=lambda: transformers.BertForQuestionAnswering(config),
+                   data_gen_fn=data_gen_for_qa,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
@@ -7,6 +7,7 @@ from transformers.models.bert.modeling_bert import BertForPreTraining
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -35,16 +36,20 @@ def check_bert_for_pretraining_forward():
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     rank = dist.get_rank()
     # print(rank)
+    layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
+    stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)

     x = torch.randint(0, 1000, (2, 3))
     hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
     if stage_manager.stage == 0:
         attention_mask = torch.ones_like(x)
-        output = bert_for_pretraining_forward(self=model,
-                                              input_ids=x,
-                                              attention_mask=attention_mask,
-                                              stage_manager=stage_manager)
-        print(output['hidden_states'].shape)
+        output = bert_for_pretraining_forward(
+            self=model,
+            input_ids=x,
+            attention_mask=attention_mask,
+            stage_manager=stage_manager,
+            stage_index=stage_index,
+        )
         assert output['hidden_states'].shape == (2, 3, 768)

     else:
@@ -52,8 +57,8 @@ def check_bert_for_pretraining_forward():
         output = bert_for_pretraining_forward(self=model,
                                               hidden_states=hidden_states,
                                               attention_mask=attention_mask,
-                                              stage_manager=stage_manager)
-        print(output[0].shape)
+                                              stage_manager=stage_manager,
+                                              stage_index=stage_index)
         assert output[0].shape == (2, 3, 30522)
         # assert output[1].shape == (2, 768)
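Both branches of the test rely on the same bookkeeping: the encoder layers are first split as evenly as possible across the pipeline stages, then each stage looks up the slice of layer indices it owns and passes it to the forward as stage_index. The real helpers are Policy.distribute_layers and Policy.get_stage_index; the following is a hypothetical re-implementation for illustration, so the exact return format is an assumption:

from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Spread layers as evenly as possible; earlier stages absorb the remainder.
    base, remainder = divmod(num_layers, num_stages)
    return [base + (1 if stage < remainder else 0) for stage in range(num_stages)]

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # Half-open [start, end) range of encoder-layer indices owned by this stage.
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

layers_per_stage = distribute_layers(12, 2)          # a 12-layer BERT encoder on 2 stages -> [6, 6]
assert get_stage_index(layers_per_stage, 0) == (0, 6)
assert get_stage_index(layers_per_stage, 1) == (6, 12)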
@@ -7,12 +7,13 @@ from transformers.models.bert.modeling_bert import BertLMHeadModel
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lmhead_forward
+from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lm_head_model_forward
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.testing import rerun_if_address_is_in_use, spawn


-def check_bert_lmhead_forward():
+def check_bert_lm_head_model_forward():
     configuration = BertConfig()
     model = BertLMHeadModel(configuration)
     DP_DIM, PP_DIM = 0, 1
@@ -35,24 +36,28 @@ def check_bert_lmhead_forward():
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     rank = dist.get_rank()
     # print(rank)

+    layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
+    stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
     x = torch.randint(0, 1000, (2, 3))
     hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
     if stage_manager.stage == 0:
         attention_mask = torch.ones_like(x)
-        output = bert_lmhead_forward(self=model,
-                                     input_ids=x,
-                                     attention_mask=attention_mask,
-                                     stage_manager=stage_manager)
+        output = bert_lm_head_model_forward(self=model,
+                                            input_ids=x,
+                                            attention_mask=attention_mask,
+                                            stage_manager=stage_manager,
+                                            stage_index=stage_index)
         print(output['hidden_states'].shape)
         assert output['hidden_states'].shape == (2, 3, 768)

     else:
         attention_mask = torch.ones((2, 3))
-        output = bert_lmhead_forward(self=model,
-                                     hidden_states=hidden_states,
-                                     attention_mask=attention_mask,
-                                     stage_manager=stage_manager)
+        output = bert_lm_head_model_forward(self=model,
+                                            hidden_states=hidden_states,
+                                            attention_mask=attention_mask,
+                                            stage_manager=stage_manager,
+                                            stage_index=stage_index)
         print(output[0].shape)
         assert output[0].shape == (2, 3, 30522)

@@ -93,7 +98,7 @@ def check_bert_lmhead_policy():

 def run_dist_model(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
-    check_bert_lmhead_forward()
+    check_bert_lm_head_model_forward()


 def run_dist_policy(rank, world_size, port):
@@ -103,7 +108,7 @@ def run_dist_policy(rank, world_size, port):

 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_bert_lmhead_forward():
+def test_bert_lm_head_model_forward():
     spawn(run_dist_model, 4)

@@ -115,5 +120,5 @@ def test_bert_lmhead_policy():

 if __name__ == "__main__":
     """test the bert for pretraining model forward and bert for pretraining model policy"""
-    test_bert_lmhead_forward()
+    test_bert_lm_head_model_forward()
     test_bert_lmhead_policy()
@@ -6,12 +6,14 @@ from transformers.models.bert.modeling_bert import BertModel
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.policies.bert import BertModelPolicy, bert_model_forward
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.testing import rerun_if_address_is_in_use, spawn


 def check_bert_model_forward():
+    # this test may crash for internet reasons
     model = BertModel.from_pretrained('bert-base-uncased')
     DP_DIM, PP_DIM = 0, 1
     DP_SIZE, PP_SIZE = 2, 2
@@ -34,20 +36,25 @@ def check_bert_model_forward():
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     rank = dist.get_rank()
     # print(rank)

+    layers_per_stage = Policy.distribute_layers(len(model.encoder.layer), 2)
+    stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
     x = torch.randint(0, 1000, (2, 3))
     hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
     if stage_manager.stage == 0:
         attention_mask = torch.ones_like(x)
-        output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
-        print(output['hidden_states'].shape)
+        output = bert_model_forward(self=model,
+                                    input_ids=x,
+                                    attention_mask=attention_mask,
+                                    stage_manager=stage_manager,
+                                    stage_index=stage_index)
         assert output['hidden_states'].shape == (2, 3, 768)
     else:
         attention_mask = torch.ones((2, 3))
         output = bert_model_forward(self=model,
                                     hidden_states=hidden_states,
                                     attention_mask=attention_mask,
-                                    stage_manager=stage_manager)
+                                    stage_manager=stage_manager,
+                                    stage_index=stage_index)
         print(output[0].shape)
         assert output[0].shape == (2, 3, 768)

@@ -112,4 +119,3 @@ if __name__ == "__main__":
     """test the bert model forward and bert model policy"""
     #test_bert_model_forward()
     test_bert_model_policy()
     # this test need config to run
@@ -49,7 +49,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
     # prepare input
     data = data_gen_fn()
     data = {k: v.cuda() for k, v in data.items()}
-
     # switch to train mode
     original_model.train()
     sharded_model.train()
tests/test_shardformer/test_model/test_pure_pipeline.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import random
from contextlib import nullcontext
from typing import Any, Callable, Iterator, List, Optional, Tuple

import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import (
    assert_hf_output_close,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward

DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2


class PipelineOptimizer(OptimizerWrapper):

    def __init__(self, optim: Optimizer, model: Module):
        super().__init__(optim)
        params = set(model.parameters())
        new_param_groups = []
        for group in optim.param_groups:
            params = [p for p in group['params'] if p in params]
            new_param_groups.append({**group, 'params': params})
        optim.__setstate__({'param_groups': new_param_groups})
        # TODO: support amp


class PipelinedModel(ModelWrapper):

    def __init__(self, module: Module, shard_config: ShardConfig, stage_manager: PipelineStageManager) -> None:
        self.stage_manager = stage_manager
        shardformer = ShardFormer(shard_config)
        module, self.shared_params = shardformer.optimize(module)
        self.shared_param_process_groups = []
        super().__init__(module)


def prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0):
    sampler = DistributedSampler(
        dataset,
        #rank=self.pg_mesh.coordinate(DP_AXIS),
        shuffle=shuffle)

    # Deterministic dataloader
    def seed_worker(worker_id):
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        worker_init_fn=seed_worker,
        drop_last=drop_last,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )


def execute_pipeline(
    data_iter: Iterator,
    model: PipelinedModel,
    criterion: Callable[[Any, Any], torch.Tensor],
    optimizer: PipelineOptimizer,
    return_loss: bool = True,
    return_outputs: bool = False,
    schedule: OneForwardOneBackwardSchedule = None,
) -> dict:
    # return loss or outputs if needed
    outputs = schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs)
    return outputs


class data_iter():

    def __getitem__(self, x):
        return torch.randint(0, 100, (4, 128)).cuda()


def loss(x, y):
    return (x[0].float().mean() - y[0].float().mean())


@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
    PP_DIM = 0
    PP_SIZE = 2
    RANK_TO_COORDINATE = {
        0: (0, 0),
        1: (0, 1),
        2: (1, 0),
        3: (1, 1),
    }
    PP_RANKS_IN_GROUP = {
        0: [0, 1],
        1: [0, 1],
        2: [2, 3],
        3: [2, 3],
    }
    from datasets import load_dataset

    #dataset = load_dataset("open_subtitles", lang1="fi", lang2="hi")
    pg_mesh = ProcessGroupMesh(PP_SIZE)
    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        num_microbatches = 2
        org_model = model_fn().cuda()
        optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
        #dataloader=prepare_dataloader(dataset=dataset['train'],batch_size=4)
        schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
        shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                   enable_tensor_parallelism=enable_tensor_parallelism,
                                   pipeline_stage_manager=stage_manager)
        pipelined_model = PipelinedModel(org_model, shard_config, stage_manager)
        pp_optimizer = PipelineOptimizer(optimizer, pipelined_model)
        data_it = iter(data_iter())
        results = execute_pipeline(data_it, pipelined_model, loss, pp_optimizer, schedule=schedule)
        if stage_manager.is_last_stage():
            assert results['loss'] is not None
            assert results['outputs'] is None
        torch.cuda.empty_cache()


def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_llama_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    spawn(check_llama, 2)


if __name__ == "__main__":
    test_llama()
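The data_iter class in the new test file defines only __getitem__, which is enough for iter() because Python falls back to the legacy sequence-iteration protocol: the iterator calls __getitem__(0), __getitem__(1), and so on until IndexError is raised, so the class yields an endless stream of random batches. A CPU-only sketch of the same idea:

import torch

class RandomBatches:
    # Only __getitem__ is defined; iter() wraps it in a sequence iterator.
    def __getitem__(self, index):
        return torch.randint(0, 100, (4, 128))

it = iter(RandomBatches())
first, second = next(it), next(it)
assert first.shape == (4, 128) and second.shape == (4, 128)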
@@ -45,25 +45,37 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)

     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
-    x = torch.randint(0, 1000, (2, 3)).cuda()
-    hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        if name == 'transformers_bert':
-            org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                            enable_tensor_parallelism, use_lazy_init)
+        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
+                                                        enable_tensor_parallelism, use_lazy_init)
+
+        if name == 'transformers_bert_for_mcq':
+            x = torch.randint(0, 1000, (2, 3, 3)).cuda()
+            attention_mask = torch.ones_like(x).cuda()
+            if stage_manager.stage == 0:
+                output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
+                assert output['hidden_states'].shape == (6, 3, 128)
+            else:
+                hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
+                output = sharded_model(input_ids=x,
+                                       hidden_states=hidden_states,
+                                       attention_mask=attention_mask,
+                                       stage_manager=stage_manager)
+                assert output[0].shape == (2, 3)
+        else:
+            x = torch.randint(0, 1000, (2, 3)).cuda()
+            # one batch, 2 single sentences, each sentence has 3 tokens
+            hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
             if stage_manager.stage == 0:
                 attention_mask = torch.ones_like(x).cuda()
                 output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
                 # print(output['hidden_states'].shape)
                 assert output['hidden_states'].shape == (2, 3, 128)
             else:
                 attention_mask = torch.ones((2, 3)).cuda()
                 output = sharded_model(hidden_states=hidden_states,
                                        attention_mask=attention_mask,
                                        stage_manager=stage_manager)
                 # print(output[0].shape)
-                assert output[0].shape == (2, 3, 128)
+                assert output[0].shape[0] == 2

     torch.cuda.empty_cache()
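The (6, 3, 128) hidden-state shape in the multiple-choice branch comes from transformers flattening a (batch, num_choices, seq_len) input to (batch * num_choices, seq_len) before running the encoder, so 2 samples with 3 choices become 6 sequences, and the final logits are reshaped back to (batch, num_choices). A small sketch against the plain transformers model, assuming a reduced config chosen only for this example:

import torch
import transformers

config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2,
                                 num_attention_heads=4, intermediate_size=256)
model = transformers.BertForMultipleChoice(config)

input_ids = torch.randint(0, 1000, (2, 3, 3))          # (batch=2, num_choices=3, seq_len=3)
attention_mask = torch.ones_like(input_ids)
output = model(input_ids=input_ids, attention_mask=attention_mask)
# Internally the encoder sees 2 * 3 = 6 sequences; the classifier output is per choice.
assert output.logits.shape == (2, 3)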