Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-13 22:59:05 +00:00)
chore(cli): select ALL rules with exclusions (#31936)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
parent 09a616fe85
commit cf2b4bbe09
@@ -13,7 +13,7 @@ def create_demo_server(
     *,
     config_keys: Sequence[str] = (),
     playground_type: Literal["default", "chat"] = "default",
-):
+) -> FastAPI:
     """Create a demo server for the current template."""
     app = FastAPI()
     package_root = get_package_root()
@@ -40,11 +40,11 @@ def create_demo_server(
     return app
 
 
-def create_demo_server_configurable():
+def create_demo_server_configurable() -> FastAPI:
     """Create a configurable demo server."""
     return create_demo_server(config_keys=["configurable"])
 
 
-def create_demo_server_chat():
+def create_demo_server_chat() -> FastAPI:
     """Create a chat demo server."""
     return create_demo_server(playground_type="chat")
@@ -8,6 +8,7 @@ from pathlib import Path
 from typing import Annotated, Optional
 
 import typer
+import uvicorn
 
 from langchain_cli.utils.events import create_events
 from langchain_cli.utils.git import (
@@ -261,7 +262,7 @@ def add(
     cmd = ["pip", "install", "-e", *installed_destination_strs]
     cmd_str = " \\\n ".join(installed_destination_strs)
     typer.echo(f"Running: pip install -e \\\n {cmd_str}")
-    subprocess.run(cmd, cwd=cwd)  # noqa: S603
+    subprocess.run(cmd, cwd=cwd, check=True)  # noqa: S603
 
     chain_names = []
     for e in installed_exports:
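Note on check=True: subprocess.run only raises on a non-zero exit status when check=True is passed; otherwise the failure is reported solely through the returned CompletedProcess. This is the behavior ruff's PLW1510 rule (listed as unfixable further down) asks callers to make explicit. A minimal sketch, with a hypothetical command:

import subprocess

# Hypothetical command used only for illustration.
cmd = ["pip", "--version"]

try:
    # With check=True a non-zero exit status raises CalledProcessError
    # instead of being silently ignored.
    subprocess.run(cmd, check=True)
except subprocess.CalledProcessError as err:
    print(f"command failed with exit code {err.returncode}")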
@@ -367,8 +368,6 @@ def serve(
     app_str = app if app is not None else "app.server:app"
     host_str = host if host is not None else "127.0.0.1"
 
-    import uvicorn
-
     uvicorn.run(
         app_str,
         host=host_str,
@@ -129,6 +129,7 @@ def new(
         subprocess.run(
             ["poetry", "install", "--with", "lint,test,typing,test_integration"],  # noqa: S607
             cwd=destination_dir,
+            check=True,
         )
     else:
         # confirm src and dst are the same length
@@ -3,6 +3,7 @@
 import importlib
 import inspect
 import pkgutil
+from types import ModuleType
 
 
 def generate_raw_migrations(
@@ -89,7 +90,7 @@ def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]:
     items = []
 
     # Function to handle importing from modules
-    def handle_module(module, module_name) -> None:
+    def handle_module(module: ModuleType, module_name: str) -> None:
         if hasattr(module, "__all__"):
             all_objects = module.__all__
             for name in all_objects:
@@ -5,10 +5,9 @@ import inspect
 import os
 import pathlib
 from pathlib import Path
+from types import ModuleType
 from typing import Any, Optional
 
-from typing_extensions import override
-
 HERE = Path(__file__).parent
 # Should bring us to [root]/src
 PKGS_ROOT = HERE.parent.parent.parent.parent.parent
@@ -19,15 +18,14 @@ PARTNER_PKGS = PKGS_ROOT / "partners"
 
 
 class ImportExtractor(ast.NodeVisitor):
-    """Import extractor"""
+    """Import extractor."""
 
     def __init__(self, *, from_package: Optional[str] = None) -> None:
         """Extract all imports from the given code, optionally filtering by package."""
         self.imports: list = []
         self.package = from_package
 
-    @override
-    def visit_ImportFrom(self, node) -> None:
+    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:  # noqa: N802
         if node.module and (
             self.package is None or str(node.module).startswith(self.package)
         ):
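For context, ImportExtractor and the visitors below follow the standard ast.NodeVisitor pattern: methods named visit_<NodeType> are dispatched per node type, and generic_visit keeps walking child nodes. A minimal sketch with a hypothetical input string:

import ast


class ImportFromCollector(ast.NodeVisitor):
    """Collect the modules referenced by `from ... import ...` statements."""

    def __init__(self) -> None:
        self.modules: list[str] = []

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:  # noqa: N802
        if node.module:
            self.modules.append(node.module)
        self.generic_visit(node)


source = "from langchain_core.messages import AIMessage\n"  # hypothetical input
collector = ImportFromCollector()
collector.visit(ast.parse(source))
assert collector.modules == ["langchain_core.messages"]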
@@ -46,7 +44,7 @@ def _get_class_names(code: str) -> list[str]:
 
     # Define a node visitor class to collect class names
     class ClassVisitor(ast.NodeVisitor):
-        def visit_ClassDef(self, node) -> None:  # noqa: N802
+        def visit_ClassDef(self, node: ast.ClassDef) -> None:  # noqa: N802
             class_names.append(node.name)
             self.generic_visit(node)
 
@@ -65,7 +63,7 @@ def is_subclass(class_obj: Any, classes_: list[type]) -> bool:
     )
 
 
-def find_subclasses_in_module(module, classes_: list[type]) -> list[str]:
+def find_subclasses_in_module(module: ModuleType, classes_: list[type]) -> list[str]:
     """Find all classes in the module that inherit from one of the classes."""
     subclasses = []
     # Iterate over all attributes of the module that are classes
@@ -77,8 +75,7 @@ def find_subclasses_in_module(module, classes_: list[type]) -> list[str]:
 
 def _get_all_classnames_from_file(file: Path, pkg: str) -> list[tuple[str, str]]:
     """Extract all class names from a file."""
-    with open(file, encoding="utf-8") as f:
-        code = f.read()
+    code = Path(file).read_text(encoding="utf-8")
     module_name = _get_current_module(file, pkg)
     class_names = _get_class_names(code)
 
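The pathlib change above is behavior-preserving: Path.read_text opens, reads, and closes the file in one call, which is what the open()/f.read() pair did. A minimal sketch, with a hypothetical file created only for illustration:

from pathlib import Path

path = Path("example.py")  # hypothetical file
path.write_text("print('hello')\n", encoding="utf-8")

via_pathlib = path.read_text(encoding="utf-8")
with open(path, encoding="utf-8") as f:
    via_open = f.read()

# Both approaches yield the same string.
assert via_pathlib == via_open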
@@ -91,8 +88,7 @@ def identify_all_imports_in_file(
     from_package: Optional[str] = None,
 ) -> list[tuple[str, str]]:
     """Let's also identify all the imports in the given file."""
-    with open(file, encoding="utf-8") as f:
-        code = f.read()
+    code = Path(file).read_text(encoding="utf-8")
     return find_imports_from_package(code, from_package=from_package)
 
 
@@ -162,8 +158,7 @@ def find_imports_from_package(
 
 def _get_current_module(path: Path, pkg_root: str) -> str:
     """Convert a path to a module name."""
-    path_as_pathlib = pathlib.Path(os.path.abspath(path))
-    relative_path = path_as_pathlib.relative_to(pkg_root).with_suffix("")
+    relative_path = path.relative_to(pkg_root).with_suffix("")
     posix_path = relative_path.as_posix()
     norm_path = os.path.normpath(str(posix_path))
     fully_qualified_module = norm_path.replace("/", ".")
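A small worked example of the path-to-module-name conversion _get_current_module performs, assuming a hypothetical package layout:

from pathlib import Path

# Hypothetical paths used only for illustration.
pkg_root = Path("/repo/libs/partners/foo")
file_path = pkg_root / "langchain_foo" / "chat_models.py"

# Make the path relative to the package root and drop the ".py" suffix...
relative = file_path.relative_to(pkg_root).with_suffix("")
# ...then turn path separators into dots to get the dotted module name.
module = relative.as_posix().replace("/", ".")
assert module == "langchain_foo.chat_models"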
@@ -7,7 +7,9 @@ from pathlib import Path
 from typing import Annotated, Optional
 
 import typer
+import uvicorn
 
+from langchain_cli.utils.github import list_packages
 from langchain_cli.utils.packages import get_langserve_export, get_package_root
 
 package_cli = typer.Typer(no_args_is_help=True, add_completion=False)
@@ -79,7 +81,7 @@ def new(
 
     # poetry install
     if with_poetry:
-        subprocess.run(["poetry", "install"], cwd=destination_dir)  # noqa: S607
+        subprocess.run(["poetry", "install"], cwd=destination_dir, check=True)  # noqa: S607
 
 
 @package_cli.command()
@@ -128,8 +130,6 @@ def serve(
         )
     )
 
-    import uvicorn
-
     uvicorn.run(
         script,
         factory=True,
@@ -142,8 +142,6 @@ def serve(
 @package_cli.command()
 def list(contains: Annotated[Optional[str], typer.Argument()] = None) -> None:  # noqa: A001
     """List all or search for available templates."""
-    from langchain_cli.utils.github import list_packages
-
     packages = list_packages(contains=contains)
     for package in packages:
         typer.echo(package)
@@ -16,6 +16,7 @@ class EventDict(TypedDict):
         event: The name of the event.
         properties: Optional dictionary of event properties.
     """
+
     event: str
     properties: Optional[dict[str, Any]]
 
@@ -1,6 +1,7 @@
 """Git utilities."""
 
 import hashlib
+import logging
 import re
 import shutil
 from collections.abc import Sequence
@@ -15,6 +16,8 @@ from langchain_cli.constants import (
     DEFAULT_GIT_SUBDIRECTORY,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class DependencySource(TypedDict):
     """Dependency source information."""
@@ -181,16 +184,15 @@ def update_repo(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path:
         # try pulling
         try:
             repo = Repo(repo_path)
-            if repo.active_branch.name != ref:
-                raise ValueError
-            repo.remotes.origin.pull()
+            if repo.active_branch.name == ref:
+                repo.remotes.origin.pull()
+                return repo_path
         except Exception:
-            # if it fails, delete and clone again
-            shutil.rmtree(repo_path)
-            Repo.clone_from(gitstring, repo_path, branch=ref, depth=1)
-    else:
-        Repo.clone_from(gitstring, repo_path, branch=ref, depth=1)
+            logger.exception("Failed to pull existing repo")
+            # if it fails, delete and clone again
+            shutil.rmtree(repo_path)
 
+    Repo.clone_from(gitstring, repo_path, branch=ref, depth=1)
     return repo_path
 
 
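The reworked block above pulls only when the checked-out branch already matches the requested ref, otherwise it falls back to a fresh shallow clone, and logger.exception records the traceback of a failed pull. A minimal, self-contained sketch of the same pull-or-re-clone pattern (the URL, ref, and path are hypothetical; GitPython is assumed, as in the surrounding code):

import logging
import shutil
from pathlib import Path

from git import Repo  # GitPython

logger = logging.getLogger(__name__)


def pull_or_clone(url: str, ref: str, repo_path: Path) -> Path:
    """Reuse an existing checkout when possible, otherwise clone from scratch."""
    if repo_path.exists():
        try:
            repo = Repo(repo_path)
            if repo.active_branch.name == ref:
                repo.remotes.origin.pull()
                return repo_path
        except Exception:
            # Inside an except block, logger.exception logs the message
            # together with the full traceback.
            logger.exception("Failed to pull existing repo")
        # Wrong branch or broken checkout: start over from a clean directory.
        shutil.rmtree(repo_path)

    Repo.clone_from(url, repo_path, branch=ref, depth=1)
    return repo_path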
@@ -203,7 +205,7 @@ def copy_repo(
     Raises FileNotFound error if it can't find source
     """
 
-    def ignore_func(_, files):
+    def ignore_func(_: str, files: list[str]) -> list[str]:
         return [f for f in files if f == ".git"]
 
     shutil.copytree(source, destination, ignore=ignore_func)
@@ -39,7 +39,7 @@ class LangServeExport(TypedDict):
 
 def get_langserve_export(filepath: Path) -> LangServeExport:
     """Get LangServe export information from a pyproject.toml file."""
-    with open(filepath) as f:
+    with filepath.open() as f:
         data: dict[str, Any] = load(f)
     try:
         module = data["tool"]["langserve"]["export_module"]
@@ -20,7 +20,7 @@ def add_dependencies_to_pyproject_toml(
     local_editable_dependencies: Iterable[tuple[str, Path]],
 ) -> None:
     """Add dependencies to pyproject.toml."""
-    with open(pyproject_toml, encoding="utf-8") as f:
+    with pyproject_toml.open(encoding="utf-8") as f:
         # tomlkit types aren't amazing - treat as Dict instead
         pyproject: dict[str, Any] = load(f)
         pyproject["tool"]["poetry"]["dependencies"].update(
@@ -29,7 +29,7 @@ def add_dependencies_to_pyproject_toml(
                 for name, loc in local_editable_dependencies
             },
         )
-    with open(pyproject_toml, "w", encoding="utf-8") as f:
+    with pyproject_toml.open("w", encoding="utf-8") as f:
         dump(pyproject, f)
 
 
@@ -38,12 +38,13 @@ def remove_dependencies_from_pyproject_toml(
     local_editable_dependencies: Iterable[str],
 ) -> None:
     """Remove dependencies from pyproject.toml."""
-    with open(pyproject_toml, encoding="utf-8") as f:
+    with pyproject_toml.open(encoding="utf-8") as f:
         pyproject: dict[str, Any] = load(f)
         # tomlkit types aren't amazing - treat as Dict instead
         dependencies = pyproject["tool"]["poetry"]["dependencies"]
         for name in local_editable_dependencies:
             with contextlib.suppress(KeyError):
                 del dependencies[name]
-    with open(pyproject_toml, "w", encoding="utf-8") as f:
+
+    with pyproject_toml.open("w", encoding="utf-8") as f:
         dump(pyproject, f)
@@ -48,53 +48,40 @@ exclude = [
 ]
 
 [tool.ruff.lint]
-select = [
-    "A", # flake8-builtins
-    "B", # flake8-bugbear
-    "ARG", # flake8-unused-arguments
-    "ASYNC", # flake8-async
-    "C4", # flake8-comprehensions
-    "COM", # flake8-commas
-    "D1", # pydocstyle
-    "E", # pycodestyle error
-    "EM", # flake8-errmsg
-    "F", # pyflakes
-    "FA", # flake8-future-annotations
-    "FBT", # flake8-boolean-trap
-    "FLY", # flake8-flynt
-    "I", # isort
-    "ICN", # flake8-import-conventions
-    "INT", # flake8-gettext
-    "ISC", # isort-comprehensions
-    "N", # pep8-naming
-    "PT", # flake8-pytest-style
-    "PGH", # pygrep-hooks
-    "PIE", # flake8-pie
-    "PERF", # flake8-perf
-    "PYI", # flake8-pyi
-    "Q", # flake8-quotes
-    "RET", # flake8-return
-    "RSE", # flake8-rst-docstrings
-    "RUF", # ruff
-    "S", # flake8-bandit
-    "SLF", # flake8-self
-    "SLOT", # flake8-slots
-    "SIM", # flake8-simplify
-    "T10", # flake8-debugger
-    "T20", # flake8-print
-    "TID", # flake8-tidy-imports
-    "UP", # pyupgrade
-    "W", # pycodestyle warning
-    "YTT", # flake8-2020
-]
+select = [ "ALL",]
 ignore = [
-    "D407", # pydocstyle: Missing-dashed-underline-after-section
+    "C90", # McCabe complexity
     "COM812", # Messes with the formatter
+    "FIX002", # Line contains TODO
+    "PERF203", # Rarely useful
+    "PLR09", # Too many something (arg, statements, etc)
+    "RUF012", # Doesn't play well with Pydantic
+    "TC001", # Doesn't play well with Pydantic
+    "TC002", # Doesn't play well with Pydantic
+    "TC003", # Doesn't play well with Pydantic
+    "TD002", # Missing author in TODO
+    "TD003", # Missing issue link in TODO
+
+    # TODO rules
+    "ANN401",
+    "BLE",
+    "D1",
 ]
+unfixable = [
+    "B028", # People should intentionally tune the stacklevel
+    "PLW1510", # People should intentionally set the check argument
+]
+
+flake8-annotations.allow-star-arg-any = true
+flake8-annotations.mypy-init-return = true
+flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
+pep8-naming.classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",]
+pydocstyle.convention = "google"
 pyupgrade.keep-runtime-typing = true
 
 [tool.ruff.lint.per-file-ignores]
-"tests/**" = [ "D1"]
+"tests/**" = [ "D1", "S", "SLF",]
+"scripts/**" = [ "INP", "S",]
 
 [tool.mypy]
 exclude = [
@@ -1,8 +1,8 @@
 """Script to generate migrations for the migration script."""
 
 import json
-import os
 import pkgutil
+from pathlib import Path
 from typing import Optional
 
 import click
@@ -73,8 +73,7 @@ def generic(
     else:
         dumped = dump_migrations_as_grit(name, migrations)
 
-    with open(output, "w") as f:
-        f.write(dumped)
+    Path(output).write_text(dumped)
 
 
 def handle_partner(pkg: str, output: Optional[str] = None) -> None:
@@ -85,8 +84,7 @@ def handle_partner(pkg: str, output: Optional[str] = None) -> None:
     data = dump_migrations_as_grit(name, migrations)
     output_name = f"{name}.grit" if output is None else output
     if migrations:
-        with open(output_name, "w") as f:
-            f.write(data)
+        Path(output_name).write_text(data)
         click.secho(f"LangChain migration script saved to {output_name}")
     else:
         click.secho(f"No migrations found for {pkg}", fg="yellow")
@@ -105,13 +103,13 @@ def partner(pkg: str, output: str) -> None:
 @click.argument("json_file")
 def json_to_grit(json_file: str) -> None:
     """Generate a Grit migration from an old JSON migration file."""
-    with open(json_file) as f:
+    file = Path(json_file)
+    with file.open() as f:
         migrations = json.load(f)
-    name = os.path.basename(json_file).removesuffix(".json").removesuffix(".grit")
+    name = file.stem
    data = dump_migrations_as_grit(name, migrations)
     output_name = f"{name}.grit"
-    with open(output_name, "w") as f:
-        f.write(data)
+    Path(output_name).write_text(data)
     click.secho(f"GritQL migration script saved to {output_name}")
 
 
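For reference, Path.stem used above strips only the final suffix of the file name, while the replaced os.path.basename(...).removesuffix(".json").removesuffix(".grit") chain also dropped a trailing ".grit". A small sketch with hypothetical file names:

from pathlib import Path

print(Path("migration.json").stem)       # migration
print(Path("migration.grit.json").stem)  # migration.grit

# The previous chained removesuffix calls also removed a trailing ".grit":
name = "migration.grit.json".removesuffix(".json").removesuffix(".grit")
print(name)                              # migration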
libs/cli/tests/integration_tests/__init__.py (new file, 0 lines)
@@ -14,3 +14,6 @@ class File:
             return False
 
         return self.content == __value.content
+
+    def __hash__(self) -> int:
+        return hash((self.name, self.content))
@@ -57,3 +57,6 @@ class Folder:
             return False
 
         return True
+
+    def __hash__(self) -> int:
+        return hash((self.name, tuple(self.files)))
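Adding __hash__ next to an existing __eq__ matters because Python sets __hash__ to None on any class that defines __eq__ without it, which makes instances unhashable. A minimal sketch, assuming a simplified stand-in for the File test helper:

class File:
    """Simplified stand-in: a name plus string content."""

    def __init__(self, name: str, content: str = "") -> None:
        self.name = name
        self.content = content

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, File):
            return NotImplemented
        return self.name == other.name and self.content == other.content

    # Without this, the __eq__ definition above would set __hash__ to None
    # and File instances could not be used in sets or as dict keys.
    def __hash__(self) -> int:
        return hash((self.name, self.content))


assert File("a.py", "x") in {File("a.py", "x")}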