mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-20 17:10:03 +00:00
[misc] Use dist logger in plugins (#6011)
* use dist logger in plugins * remove trash * print on rank 0 --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu>
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
||||
|
||||
@@ -8,6 +7,8 @@ from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
SUPPORT_PEFT = False
|
||||
try:
|
||||
import peft
|
||||
@@ -81,12 +82,15 @@ class Booster:
|
||||
plugin, Plugin
|
||||
), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
|
||||
self.plugin = plugin
|
||||
self.logger = get_dist_logger()
|
||||
|
||||
# set accelerator
|
||||
if self.plugin and self.plugin.control_device():
|
||||
self.accelerator = None
|
||||
if device is not None:
|
||||
warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
|
||||
self.logger.warning(
|
||||
"The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0]
|
||||
)
|
||||
else:
|
||||
device = device or "cuda"
|
||||
self.accelerator = Accelerator(device)
|
||||
@@ -94,7 +98,10 @@ class Booster:
|
||||
# set precision
|
||||
if self.plugin and self.plugin.control_precision():
|
||||
if mixed_precision is not None:
|
||||
warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
|
||||
self.logger.warning(
|
||||
"The plugin will control the precision," "so the mixed_precision argument will be ignored.",
|
||||
ranks=[0],
|
||||
)
|
||||
self.mixed_precision = None
|
||||
elif mixed_precision is None:
|
||||
self.mixed_precision = None
|
||||
@@ -267,8 +274,9 @@ class Booster:
|
||||
), "Please provide pretrained directory path if not passing in lora configuration."
|
||||
if quantize is True:
|
||||
if bnb_quantization_config is not None:
|
||||
warnings.warn(
|
||||
"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
|
||||
self.logger.warning(
|
||||
"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.",
|
||||
ranks=[0],
|
||||
)
|
||||
else:
|
||||
bnb_quantization_config = BnbQuantizationConfig(
|
||||
|
Reference in New Issue
Block a user