Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-11 05:49:22 +00:00
ci: make ci happy lint the code, delete unused imports
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
@@ -1,27 +1,28 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
-import torch
-
 import datetime
 import logging
 import logging.handlers
 import os
 import sys
 
 import requests
+import torch
 
 from pilot.configs.model_config import LOGDIR
 
-server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+server_error_msg = (
+    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+)
 
 handler = None
 
 
 def get_gpu_memory(max_gpus=None):
     gpu_memory = []
     num_gpus = (
         torch.cuda.device_count()
         if max_gpus is None
         else min(max_gpus, torch.cuda.device_count())
     )
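The `server_error_msg` change in this hunk is purely cosmetic: the formatter wraps the long literal in parentheses so the line fits its length limit. A minimal sketch showing the two forms are equivalent at runtime (`msg_a` and `msg_b` are illustrative names, not from the diff):

# The parentheses only group the expression for line wrapping;
# both assignments bind the same string contents.
msg_a = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
msg_b = (
    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
assert msg_a == msg_b  # identical strings, different layout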
@@ -29,14 +30,13 @@ def get_gpu_memory(max_gpus=None):
         with torch.cuda.device(gpu_id):
             device = torch.cuda.current_device()
             gpu_properties = torch.cuda.get_device_properties(device)
-            total_memory = gpu_properties.total_memory / (1024 ** 3)
-            allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3)
+            total_memory = gpu_properties.total_memory / (1024**3)
+            allocated_memory = torch.cuda.memory_allocated() / (1024**3)
             available_memory = total_memory - allocated_memory
             gpu_memory.append(available_memory)
     return gpu_memory
 
 
-
 def build_logger(logger_name, logger_filename):
     global handler
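For context, here is a self-contained sketch of the whole helper as it reads after this hunk; the `1024**3` divisor converts bytes to GiB, and the `for gpu_id in range(num_gpus):` header is inferred from the context lines (it sits between the two hunks and is not shown in the diff):

import torch

def get_gpu_memory(max_gpus=None):
    """Return the free memory of each visible GPU, in GiB."""
    gpu_memory = []
    num_gpus = (
        torch.cuda.device_count()
        if max_gpus is None
        else min(max_gpus, torch.cuda.device_count())
    )
    for gpu_id in range(num_gpus):
        with torch.cuda.device(gpu_id):
            device = torch.cuda.current_device()
            gpu_properties = torch.cuda.get_device_properties(device)
            total_memory = gpu_properties.total_memory / (1024**3)  # bytes -> GiB
            allocated_memory = torch.cuda.memory_allocated() / (1024**3)
            available_memory = total_memory - allocated_memory
            gpu_memory.append(available_memory)
    return gpu_memory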
@@ -47,7 +47,7 @@ def build_logger(logger_name, logger_filename):
 
     # Set the format of root handlers
     if not logging.getLogger().handlers:
-        logging.basicConfig(level=logging.INFO, encoding='utf-8')
+        logging.basicConfig(level=logging.INFO, encoding="utf-8")
     logging.getLogger().handlers[0].setFormatter(formatter)
 
     # Redirect stdout and stderr to loggers
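One caveat on the line above, beyond the quote-style change: the `encoding` keyword to `logging.basicConfig` only exists on Python 3.9+ (older interpreters reject unknown keywords with a ValueError), and per the logging docs it only takes effect when a `filename` is also given. A guarded variant, as an illustration rather than what the commit does:

import logging
import sys

# Assumption for illustration: only pass `encoding` where it is accepted.
if sys.version_info >= (3, 9):
    logging.basicConfig(level=logging.INFO, encoding="utf-8")
else:
    logging.basicConfig(level=logging.INFO)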
@@ -70,7 +70,8 @@ def build_logger(logger_name, logger_filename):
         os.makedirs(LOGDIR, exist_ok=True)
         filename = os.path.join(LOGDIR, logger_filename)
         handler = logging.handlers.TimedRotatingFileHandler(
-            filename, when='D', utc=True)
+            filename, when="D", utc=True
+        )
         handler.setFormatter(formatter)
 
         for name, item in logging.root.manager.loggerDict.items():
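The rewrapped call keeps the same behaviour: one log file that rotates once per day, with rotation times computed in UTC. A minimal standalone sketch of the same handler setup (the file name `demo.log`, logger name `demo`, and format string are made up for illustration):

import logging
import logging.handlers

logger = logging.getLogger("demo")
handler = logging.handlers.TimedRotatingFileHandler(
    "demo.log", when="D", utc=True  # rotate daily; rollover clock in UTC
)
handler.setFormatter(
    logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s")
)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("rotating file handler is wired up")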
@@ -84,35 +85,36 @@ class StreamToLogger(object):
     """
     Fake file-like stream object that redirects writes to a logger instance.
     """
 
     def __init__(self, logger, log_level=logging.INFO):
         self.terminal = sys.stdout
         self.logger = logger
         self.log_level = log_level
-        self.linebuf = ''
+        self.linebuf = ""
 
     def __getattr__(self, attr):
         return getattr(self.terminal, attr)
 
     def write(self, buf):
         temp_linebuf = self.linebuf + buf
-        self.linebuf = ''
+        self.linebuf = ""
         for line in temp_linebuf.splitlines(True):
             # From the io.TextIOWrapper docs:
             # On output, if newline is None, any '\n' characters written
             # are translated to the system default line separator.
             # By default sys.stdout.write() expects '\n' newlines and then
             # translates them so this is still cross platform.
-            if line[-1] == '\n':
-                encoded_message = line.encode('utf-8', 'ignore').decode('utf-8')
+            if line[-1] == "\n":
+                encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
                 self.logger.log(self.log_level, encoded_message.rstrip())
             else:
                 self.linebuf += line
 
     def flush(self):
-        if self.linebuf != '':
-            encoded_message = self.linebuf.encode('utf-8', 'ignore').decode('utf-8')
+        if self.linebuf != "":
+            encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
             self.logger.log(self.log_level, encoded_message.rstrip())
-        self.linebuf = ''
+        self.linebuf = ""
 
+
 def disable_torch_init():
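As its docstring says, StreamToLogger only has to look enough like a file for print() and line-oriented writes to work; `__getattr__` forwards everything else to the real `sys.stdout`, and partial lines are buffered in `linebuf` until a newline (or a flush) arrives. A usage sketch, assuming the class above is in scope (the logger name `stdout` is illustrative):

import logging
import sys

logging.basicConfig(level=logging.INFO)
stdout_logger = logging.getLogger("stdout")

# Redirect print() output into the logging pipeline.
sys.stdout = StreamToLogger(stdout_logger, logging.INFO)
print("this line now goes through the logger")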
@@ -120,6 +122,7 @@ def disable_torch_init():
     Disable the redundant torch default initialization to accelerate model creation.
     """
     import torch
+
     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
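The added blank line is cosmetic, but the function is worth a note: patching `reset_parameters` to a no-op skips the default weight initialization, which is safe when every parameter is about to be overwritten by a checkpoint load. A small sketch of the effect:

import torch

# With the no-op patch applied, constructing large layers no longer runs
# the default init kernels; the tensors hold uninitialized memory.
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

layer = torch.nn.Linear(4096, 4096)  # fast: no init pass
# layer.load_state_dict(...) would then supply the real weights.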
@@ -128,4 +131,3 @@ def pretty_print_semaphore(semaphore):
     if semaphore is None:
         return "None"
     return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
-
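For reference, `pretty_print_semaphore` peeks at the private `_value` attribute of an `asyncio.Semaphore`, so its output depends on asyncio internals. A quick usage sketch with the function restated from the diff:

import asyncio

def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"

print(pretty_print_semaphore(None))                  # None
print(pretty_print_semaphore(asyncio.Semaphore(3)))  # value=3, locked=False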