[hotfix] fix autoparallel demo (#2533)

Replaces the dropped autoparallelize entry point with initialize_model plus an explicit DeviceMesh, prints the solver's sharding solution on rank 0, and adds a torch.cuda.synchronize() before the optimizer step.
@@ -3,8 +3,9 @@ from torchvision.models import resnet50
 from tqdm import tqdm
 
 import colossalai
-from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
 from colossalai.core import global_context as gpc
+from colossalai.device.device_mesh import DeviceMesh
 from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingLR
 
@@ -22,9 +23,14 @@ def main():
 
     # trace the model with meta data
     model = resnet50(num_classes=10).cuda()
-    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
 
-    model = autoparallelize(model, input_sample)
+    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
+    device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True)
+    model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True)
+
+    if gpc.get_global_rank() == 0:
+        for node_strategy in solution:
+            print(node_strategy)
     # build criterion
     criterion = torch.nn.CrossEntropyLoss()
 
@@ -52,6 +58,7 @@ def main():
             output = model(img)
             train_loss = criterion(output, label)
             train_loss.backward(train_loss)
+            torch.cuda.synchronize()
             optimizer.step()
         lr_scheduler.step()
 
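For context, here is a minimal sketch of how the pieces added by this patch fit together as a single script. Everything except the './config.py' path (assumed to define BATCH_SIZE, as in the original demo) and the __main__ guard is taken directly from the patched code; treat it as an illustration of the new initialize_model flow, not the full example.

import torch
from torchvision.models import resnet50

import colossalai
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh


def main():
    # Launch the distributed environment; the config file path is an
    # assumption and is expected to define BATCH_SIZE.
    colossalai.launch_from_torch(config='./config.py')

    # Build the model and a meta-tensor input sample sized for the global batch.
    model = resnet50(num_classes=10).cuda()
    input_sample = {
        'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')
    }

    # Describe 4 GPUs as a 2x2 logical mesh; init_process_group=True lets the
    # mesh set up the process groups needed for sharded communication.
    device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True)

    # initialize_model replaces the old autoparallelize entry point: it shards
    # the traced model over the device mesh and, with return_solution=True,
    # also returns the per-node sharding strategy for inspection.
    model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True)

    # Print the solver's decisions on rank 0 only, to avoid duplicated output.
    if gpc.get_global_rank() == 0:
        for node_strategy in solution:
            print(node_strategy)


if __name__ == '__main__':
    main()

The torch.cuda.synchronize() added in the last hunk waits for all outstanding CUDA work from the backward pass to finish before optimizer.step() runs on the host side.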