mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-25 10:06:27 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			105 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			105 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python
 | |
| # -*- encoding: utf-8 -*-
 | |
| 
 | |
| import inspect
 | |
| import sys
 | |
| from importlib.machinery import SourceFileLoader
 | |
| from pathlib import Path
 | |
| from colossalai.logging import get_dist_logger
 | |
| 
 | |
| 
 | |
class Config(dict):
    """This is a wrapper class for dict objects so that values of which can be
    accessed as attributes.

    ``dict`` values passed to the constructor or to :meth:`update` are wrapped
    in :class:`Config` recursively, so ``Config({'a': {'b': 1}}).a.b`` is ``1``.
    Values assigned directly (``cfg.x = {...}`` or ``cfg['x'] = {...}``) are
    stored as given.

    :param config: The dict object to be wrapped
    :type config: dict
    """

    def __init__(self, config: dict = None):
        if config is not None:
            for k, v in config.items():
                self._add_item(k, v)

    def __missing__(self, key):
        raise KeyError(key)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails; fall back to the
        # dict entries and translate a missing key into AttributeError so that
        # getattr()/hasattr() behave as expected.
        try:
            value = super(Config, self).__getitem__(key)
            return value
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        # Attribute assignment is stored as a dict entry, not in __dict__.
        super(Config, self).__setitem__(key, value)

    def _add_item(self, key, value):
        """Store ``value`` under ``key``, wrapping plain dicts in :class:`Config`."""
        if isinstance(value, dict):
            self.__setattr__(key, Config(value))
        else:
            self.__setattr__(key, value)

    def update(self, config):
        """Merge ``config`` into this object and return ``self``.

        :param config: Entries to add or overwrite
        :type config: dict or :class:`Config`
        :raises AssertionError: If ``config`` is not a dict or :class:`Config`
        :return: This object, to allow chaining
        :rtype: :class:`Config`
        """
        assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.'
        for k, v in config.items():
            self._add_item(k, v)
        return self

    @staticmethod
    def from_file(filename: str):
        """Reads a python file and constructs a corresponding :class:`Config` object.

        The file is executed as a module. Every module-level name that does not
        start with ``__`` and is neither a module nor a class becomes an entry
        of the returned object. While the file runs, its parent directory is
        placed on ``sys.path`` so it can import sibling modules; ``sys.path``
        and ``sys.modules`` are restored afterwards, even if execution fails.

        :param filename: Name of the file to construct the return object
        :type filename: str or os.PathLike
        :raises AssertionError: Raises an AssertionError if the file does not exist, or the file
            is not .py file
        :raises TypeError: If ``filename`` is not a string or path-like object
        :return: A :class:`Config` object constructed with information in the file
        :rtype: :class:`Config`
        """
        # Imported locally so this method does not depend on an extra
        # module-level import.
        import importlib.util

        # Path() accepts str and any os.PathLike, so ``filepath`` is always
        # bound (unsupported types raise TypeError here).
        filepath = Path(filename).absolute()

        assert filepath.exists(), f'{filename} is not found, please check your configuration path'

        # check extension
        assert filepath.suffix == '.py', 'only .py files are supported'

        module_name = filepath.stem
        # sys.path holds strings: compare and insert the *directory* as str.
        parent_dir = str(filepath.parent)
        path_added = parent_dir not in sys.path
        if path_added:
            sys.path.insert(0, parent_dir)

        # Remember any already-imported module of the same name so that it is
        # put back afterwards instead of being overwritten and deleted.
        absent = object()
        previous_module = sys.modules.get(module_name, absent)

        try:
            # import the config as module
            loader = SourceFileLoader(fullname=module_name, path=str(filepath))
            spec = importlib.util.spec_from_loader(module_name, loader)
            module = importlib.util.module_from_spec(spec)
            # Registered only for the duration of execution, as code in the
            # config file may look its own module up in sys.modules.
            sys.modules[module_name] = module
            loader.exec_module(module)

            # load into config
            config = Config()
            for k, v in module.__dict__.items():
                if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v):
                    continue
                config._add_item(k, v)
        finally:
            # remove module
            if previous_module is absent:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous_module
            # Remove exactly the entry added above; the config file may have
            # changed sys.path itself, so index 0 is not reliable.
            if path_added and parent_dir in sys.path:
                sys.path.remove(parent_dir)

        logger = get_dist_logger()
        logger.debug('variables which starts with __, is a module or class declaration are omitted in config file')

        return config
 | |
| 
 | |
| 
 | |
class ConfigException(Exception):
    """Plain :class:`Exception` subclass that adds no behaviour of its own."""
 |