mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-02 08:40:36 +00:00
119 lines
3.3 KiB
Python
119 lines
3.3 KiB
Python
"""DAG loader.
|
|
|
|
DAGLoader will load DAGs from dag_dirs or other sources.
|
|
Now only support load DAGs from local files.
|
|
"""
|
|
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import sys
|
|
import traceback
|
|
from abc import ABC, abstractmethod
|
|
from typing import List
|
|
|
|
from .base import DAG
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DAGLoader(ABC):
    """Interface for components that discover and load DAGs."""

    @abstractmethod
    def load_dags(self) -> List[DAG]:
        """Return every DAG this loader can discover."""
        ...
|
|
|
|
|
|
class LocalFileDAGLoader(DAGLoader):
    """DAG loader that discovers DAGs in files on the local filesystem."""

    def __init__(self, dag_dirs: List[str]) -> None:
        """Create the loader.

        Args:
            dag_dirs (List[str]): Paths to scan; each entry may be a
                directory or a single file.
        """
        self._dag_dirs = dag_dirs

    def load_dags(self) -> List[DAG]:
        """Load dags from local files."""
        found: List[DAG] = []
        for path in self._dag_dirs:
            # Silently skip paths that do not exist.
            if not os.path.exists(path):
                continue
            # Directories are scanned for .py files; plain paths are
            # treated as single module files.
            handler = _process_directory if os.path.isdir(path) else _process_file
            found.extend(handler(path))
        return found
|
|
|
|
|
|
def _process_directory(directory: str) -> List[DAG]:
    """Collect DAGs from every ``.py`` file directly inside *directory*.

    Note: only the top level of the directory is scanned; there is no
    recursion into subdirectories.
    """
    collected: List[DAG] = []
    for entry in os.listdir(directory):
        if not entry.endswith(".py"):
            continue
        collected.extend(_process_file(os.path.join(directory, entry)))
    return collected
|
|
|
|
|
|
def _process_file(filepath) -> List[DAG]:
    """Import *filepath* as a module and extract the DAGs it defines."""
    return _process_modules(_load_modules_from_file(filepath))
|
|
|
|
|
|
def _load_modules_from_file(
|
|
filepath: str, mod_name: str | None = None, show_log: bool = True
|
|
):
|
|
import importlib
|
|
import importlib.machinery
|
|
import importlib.util
|
|
|
|
if show_log:
|
|
logger.info(f"Importing {filepath}")
|
|
|
|
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
|
|
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
|
|
if mod_name is None:
|
|
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
|
|
|
|
if mod_name in sys.modules:
|
|
del sys.modules[mod_name]
|
|
|
|
def parse(mod_name, filepath):
|
|
try:
|
|
loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
|
|
spec = importlib.util.spec_from_loader(mod_name, loader)
|
|
new_module = importlib.util.module_from_spec(spec)
|
|
sys.modules[spec.name] = new_module
|
|
loader.exec_module(new_module)
|
|
return [new_module]
|
|
except Exception:
|
|
msg = traceback.format_exc()
|
|
logger.error(f"Failed to import: {filepath}, error message: {msg}")
|
|
# TODO save error message
|
|
return []
|
|
|
|
return parse(mod_name, filepath)
|
|
|
|
|
|
def _process_modules(mods, show_log: bool = True) -> List[DAG]:
    """Extract every top-level ``DAG`` object from the given modules.

    Args:
        mods: Imported modules whose module-level attributes are scanned
            for ``DAG`` instances.
        show_log (bool): Whether to log each DAG that is found.

    Returns:
        List[DAG]: All DAGs found at module top level; a failure while
        processing one DAG is logged and does not abort the others.
    """
    top_level_dags = (
        (o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
    )
    found_dags = []
    for dag, mod in top_level_dags:
        try:
            # TODO validate dag params
            if show_log:
                logger.info(
                    f"Found dag {dag} from mod {mod} and model file {mod.__file__}"
                )
            found_dags.append(dag)
        except Exception:
            msg = traceback.format_exc()
            # Fix: previous message read "Failed to dag file", which was garbled.
            logger.error(f"Failed to process DAG, error message: {msg}")
    return found_dags
|