DB-GPT/dbgpt/core/awel/dag/loader.py
2024-05-30 18:51:57 +08:00

119 lines
3.3 KiB
Python

"""DAG loader.
DAGLoader loads DAGs from dag_dirs or other sources.
Currently, only loading DAGs from local files is supported.
"""
import hashlib
import logging
import os
import sys
import traceback
from abc import ABC, abstractmethod
from typing import List
from .base import DAG
# Module-level logger used by the loader helpers in this module.
logger = logging.getLogger(__name__)
class DAGLoader(ABC):
    """Interface for components that discover DAGs and hand them back.

    Subclasses decide where DAGs come from (for example, local Python files)
    by implementing :meth:`load_dags`.
    """

    @abstractmethod
    def load_dags(self) -> List[DAG]:
        """Return every DAG this loader is able to find."""
class LocalFileDAGLoader(DAGLoader):
    """DAG loader for loading DAGs from local files.

    Each configured path may be either a directory (its ``.py`` files are
    imported) or a single Python file, which is imported directly.
    """

    def __init__(self, dag_dirs: List[str]) -> None:
        """Initialize a LocalFileDAGLoader.

        Args:
            dag_dirs (List[str]): The directories (or individual files) to
                load DAGs from. Paths that do not exist are skipped with a
                warning.
        """
        self._dag_dirs = dag_dirs

    def load_dags(self) -> List[DAG]:
        """Load DAGs from every configured local path.

        Returns:
            List[DAG]: The DAGs found, in the order the paths were configured.
        """
        dags: List[DAG] = []
        for filepath in self._dag_dirs:
            if not os.path.exists(filepath):
                # A missing path is almost always a configuration mistake;
                # surface it instead of skipping it silently.
                logger.warning("DAG path %s does not exist, skipping it", filepath)
                continue
            if os.path.isdir(filepath):
                dags.extend(_process_directory(filepath))
            else:
                dags.extend(_process_file(filepath))
        return dags
def _process_directory(directory: str) -> List[DAG]:
    """Load DAGs from every top-level ``.py`` file in *directory*.

    The scan is not recursive. File names are visited in sorted order because
    ``os.listdir`` returns entries in arbitrary order, which would otherwise
    make the order of the returned DAGs depend on the filesystem. Entries that
    end in ``.py`` but are not regular files (e.g. a directory named
    ``foo.py``) are skipped.

    Args:
        directory: The directory to scan.

    Returns:
        List[DAG]: The DAGs found across all matching files.
    """
    dags: List[DAG] = []
    for name in sorted(os.listdir(directory)):
        if not name.endswith(".py"):
            continue
        filepath = os.path.join(directory, name)
        if not os.path.isfile(filepath):
            continue
        dags.extend(_process_file(filepath))
    return dags
def _process_file(filepath) -> List[DAG]:
    """Import the Python file at *filepath* and return the DAGs in its namespace.

    Import failures are handled by ``_load_modules_from_file``; in that case
    the result is an empty list.
    """
    return _process_modules(_load_modules_from_file(filepath))
def _load_modules_from_file(
filepath: str, mod_name: str | None = None, show_log: bool = True
):
import importlib
import importlib.machinery
import importlib.util
if show_log:
logger.info(f"Importing {filepath}")
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
if mod_name is None:
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
if mod_name in sys.modules:
del sys.modules[mod_name]
def parse(mod_name, filepath):
try:
loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
spec = importlib.util.spec_from_loader(mod_name, loader)
new_module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = new_module
loader.exec_module(new_module)
return [new_module]
except Exception:
msg = traceback.format_exc()
logger.error(f"Failed to import: {filepath}, error message: {msg}")
# TODO save error message
return []
return parse(mod_name, filepath)
def _process_modules(mods, show_log: bool = True) -> List[DAG]:
    """Collect the DAG objects found in the namespaces of *mods*.

    Every value in a module's ``__dict__`` that is a ``DAG`` is collected,
    including DAGs the module merely imported. A DAG object reachable under
    several names, or from several of the given modules, is returned once.

    Args:
        mods: The modules to scan, e.g. the result of
            ``_load_modules_from_file``.
        show_log: Whether to log an info message for each DAG found.

    Returns:
        List[DAG]: The DAGs found, in discovery order, without duplicates.
    """
    found_dags: List[DAG] = []
    seen_ids = set()
    for mod in mods:
        for obj in mod.__dict__.values():
            # De-duplicate by identity: ``alias = dag`` must not yield the
            # same DAG twice.
            if not isinstance(obj, DAG) or id(obj) in seen_ids:
                continue
            seen_ids.add(id(obj))
            try:
                # TODO validate dag params
                if show_log:
                    mod_file = getattr(mod, "__file__", None)
                    logger.info(
                        f"Found dag {obj} from mod {mod} and module file {mod_file}"
                    )
                found_dags.append(obj)
            except Exception:
                msg = traceback.format_exc()
                logger.error(
                    f"Failed to process dag from module {mod}, error message: {msg}"
                )
    return found_dags