mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-28 20:30:42 +00:00
add example of self-supervised SimCLR training - V2 (#50)
* add example of self-supervised SimCLR training
* SimCLR v2: replace the NVIDIA DALI dataloader
* updated
* sync to latest code writing style
* sync to latest code writing style and modify README
* detail the README & standardize the dataset path
This commit is contained in:
15
examples/simclr_cifar10_data_parallel/myhooks.py
Normal file
15
examples/simclr_cifar10_data_parallel/myhooks.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from colossalai.trainer.hooks import BaseHook
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
class TotalBatchsizeHook(BaseHook):
    """Trainer hook that logs the effective global batch size before training.

    The effective size is the configured per-step batch size multiplied by
    the gradient-accumulation factor and the data-parallel world size.
    """

    def __init__(self, priority: int = 2) -> None:
        # Hand the scheduling priority to the base hook machinery.
        super().__init__(priority)
        self.logger = get_dist_logger()

    def before_train(self, trainer):
        """Compute and log the total batch size (on rank 0 only)."""
        world_size = gpc.get_world_size(ParallelMode.DATA)
        accumulation_steps = gpc.config.gradient_accumulation
        effective_batch_size = gpc.config.BATCH_SIZE * accumulation_steps * world_size
        # ranks=[0] keeps the message from being printed once per process.
        self.logger.info(f'Total batch size = {effective_batch_size}', ranks=[0])
|
||||
Reference in New Issue
Block a user