[misc] Use dist logger in plugins (#6011)

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Author: Edenzzzz
Date: 2024-08-20 10:32:41 +08:00
Committed by: GitHub
Commit: dcc44aab8d (parent f1c3266a94)
8 changed files with 101 additions and 70 deletions
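
This commit swaps process-local warnings.warn calls for ColossalAI's distributed logger so that each message is printed once, on rank 0, instead of once per process. A minimal sketch of the pattern, assuming an already-launched ColossalAI distributed environment (the launch call and message text here are illustrative, not from the commit):

    # Sketch of the dist-logger pattern this commit adopts.
    import colossalai
    from colossalai.logging import get_dist_logger

    # Assumes the process group is initialized, e.g. under torchrun.
    colossalai.launch_from_torch()

    logger = get_dist_logger()

    # Unlike warnings.warn, which fires in every process, this emits
    # the message only on the listed ranks (here, rank 0 only).
    logger.warning("The plugin will control the device.", ranks=[0])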

colossalai/booster/booster.py

@@ -1,4 +1,3 @@
-import warnings
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union
 
@@ -8,6 +7,8 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from colossalai.logging import get_dist_logger
+
 SUPPORT_PEFT = False
 try:
     import peft
@@ -81,12 +82,15 @@ class Booster:
             plugin, Plugin
         ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
         self.plugin = plugin
+        self.logger = get_dist_logger()
 
         # set accelerator
         if self.plugin and self.plugin.control_device():
             self.accelerator = None
             if device is not None:
-                warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the accelerator, " "so the device argument will be ignored.", ranks=[0]
+                )
         else:
             device = device or "cuda"
             self.accelerator = Accelerator(device)
@@ -94,7 +98,10 @@ class Booster:
         # set precision
         if self.plugin and self.plugin.control_precision():
             if mixed_precision is not None:
-                warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the precision, " "so the mixed_precision argument will be ignored.",
+                    ranks=[0],
+                )
             self.mixed_precision = None
         elif mixed_precision is None:
             self.mixed_precision = None
@@ -267,8 +274,9 @@ class Booster:
            ), "Please provide pretrained directory path if not passing in lora configuration."
         if quantize is True:
             if bnb_quantization_config is not None:
-                warnings.warn(
-                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+                self.logger.warning(
+                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.",
+                    ranks=[0],
                 )
             else:
                 bnb_quantization_config = BnbQuantizationConfig(
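
Routing these messages through the dist logger with ranks=[0] keeps multi-GPU launches readable: one line per job rather than one per process. As a usage illustration, the new rank-0 path can be exercised by passing an explicit device together with a plugin that controls device placement. This is a hedged sketch, not part of the commit; TorchDDPPlugin is one such plugin in the Booster API, and the launch call is illustrative:

    # Sketch: exercising the rank-0 warning path added above.
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchDDPPlugin

    colossalai.launch_from_torch()

    # The plugin controls the device, so the explicit device argument
    # is ignored and the warning is logged once, on rank 0 only.
    booster = Booster(device="cuda", plugin=TorchDDPPlugin())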