Merge remote-tracking branch 'origin/main' into store_connector

This commit is contained in:
aries-ckt 2023-05-24 16:19:33 +08:00
commit 4ef9ada7dc
2 changed files with 16 additions and 1 deletions

14
pilot/configs/__init__.py Normal file
View File

@ -0,0 +1,14 @@
import os
import random
import sys
from dotenv import load_dotenv
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
print("Setting random seed to 42")
random.seed(42)
# Load the users .env file into environment variables
load_dotenv(verbose=True, override=True)
del load_dotenv

View File

@ -9,6 +9,7 @@ from typing import Optional
from pilot.model.compression import compress_module
from pilot.model.adapter import get_llm_model_adapter
from pilot.utils import get_gpu_memory
from pilot.configs.model_config import DEVICE
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
def raise_warning_for_incompatible_cpu_offloading_configuration(
@ -50,7 +51,7 @@ class ModelLoader(metaclass=Singleton):
def __init__(self,
model_path) -> None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = DEVICE
self.model_path = model_path
self.kwargs = {
"torch_dtype": torch.float16,