Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-30 18:33:40 +00:00
cli[minor]: Improve partner migrations (#20938)
This auto-generates partner migrations. At the moment each generated migration goes from community -> partner, so one would need to run the migration script twice to go from langchain to a partner package (langchain -> community, then community -> partner).
This commit is contained in:
parent 5653f36adc
commit 12c906f6ce
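To make the two-pass flow concrete, here is a minimal, hypothetical sketch (not part of the commit): each migration fixture is a list of [old_import_path, new_import_path] pairs, and running the community pass followed by the partner pass rewrites a legacy langchain import all the way to the partner package. The apply() helper below is illustrative only.

# Illustrative only: the pairs mirror the JSON fixtures in this commit,
# and apply() is a hypothetical helper, not part of langchain-cli.
langchain_to_community = [
    ("langchain.chat_models.ChatOpenAI", "langchain_community.chat_models.ChatOpenAI"),
]
community_to_partner = [
    ("langchain_community.chat_models.ChatOpenAI", "langchain_openai.ChatOpenAI"),
]


def apply(migrations, import_path):
    # Rewrite an import path if a migration pair matches it.
    return dict(migrations).get(import_path, import_path)


path = apply(langchain_to_community, "langchain.chat_models.ChatOpenAI")  # pass 1
path = apply(community_to_partner, path)  # pass 2
print(path)  # langchain_openai.ChatOpenAI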
@@ -0,0 +1,18 @@
+[
+  [
+    "langchain_community.llms.anthropic.Anthropic",
+    "langchain_anthropic.Anthropic"
+  ],
+  [
+    "langchain_community.chat_models.anthropic.ChatAnthropic",
+    "langchain_anthropic.ChatAnthropic"
+  ],
+  [
+    "langchain_community.llms.Anthropic",
+    "langchain_anthropic.Anthropic"
+  ],
+  [
+    "langchain_community.chat_models.ChatAnthropic",
+    "langchain_anthropic.ChatAnthropic"
+  ]
+]
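Each entry in these fixtures is an [old_path, new_path] pair. As a rough sketch of how such a fixture is consumed (the local file name "anthropic.json" is an assumption), it amounts to building a lookup from old import path to the new module and symbol:

import json

# Rough sketch: load a migration fixture like the one above and index it by old path.
with open("anthropic.json", encoding="utf-8") as f:
    pairs = json.load(f)

lookup = {old: tuple(new.rsplit(".", 1)) for old, new in pairs}
print(lookup["langchain_community.chat_models.ChatAnthropic"])
# ('langchain_anthropic', 'ChatAnthropic')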
@@ -0,0 +1,18 @@
+[
+  [
+    "langchain_community.llms.fireworks.Fireworks",
+    "langchain_fireworks.Fireworks"
+  ],
+  [
+    "langchain_community.chat_models.fireworks.ChatFireworks",
+    "langchain_fireworks.ChatFireworks"
+  ],
+  [
+    "langchain_community.llms.Fireworks",
+    "langchain_fireworks.Fireworks"
+  ],
+  [
+    "langchain_community.chat_models.ChatFireworks",
+    "langchain_fireworks.ChatFireworks"
+  ]
+]
@@ -0,0 +1,10 @@
+[
+  [
+    "langchain_community.llms.watsonxllm.WatsonxLLM",
+    "langchain_ibm.WatsonxLLM"
+  ],
+  [
+    "langchain_community.llms.WatsonxLLM",
+    "langchain_ibm.WatsonxLLM"
+  ]
+]
@@ -0,0 +1,50 @@
+[
+  [
+    "langchain_community.llms.openai.OpenAI",
+    "langchain_openai.OpenAI"
+  ],
+  [
+    "langchain_community.llms.openai.AzureOpenAI",
+    "langchain_openai.AzureOpenAI"
+  ],
+  [
+    "langchain_community.embeddings.openai.OpenAIEmbeddings",
+    "langchain_openai.OpenAIEmbeddings"
+  ],
+  [
+    "langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings",
+    "langchain_openai.AzureOpenAIEmbeddings"
+  ],
+  [
+    "langchain_community.chat_models.openai.ChatOpenAI",
+    "langchain_openai.ChatOpenAI"
+  ],
+  [
+    "langchain_community.chat_models.azure_openai.AzureChatOpenAI",
+    "langchain_openai.AzureChatOpenAI"
+  ],
+  [
+    "langchain_community.llms.AzureOpenAI",
+    "langchain_openai.AzureOpenAI"
+  ],
+  [
+    "langchain_community.llms.OpenAI",
+    "langchain_openai.OpenAI"
+  ],
+  [
+    "langchain_community.embeddings.AzureOpenAIEmbeddings",
+    "langchain_openai.AzureOpenAIEmbeddings"
+  ],
+  [
+    "langchain_community.embeddings.OpenAIEmbeddings",
+    "langchain_openai.OpenAIEmbeddings"
+  ],
+  [
+    "langchain_community.chat_models.AzureChatOpenAI",
+    "langchain_openai.AzureChatOpenAI"
+  ],
+  [
+    "langchain_community.chat_models.ChatOpenAI",
+    "langchain_openai.ChatOpenAI"
+  ]
+]
@@ -0,0 +1,10 @@
+[
+  [
+    "langchain_community.vectorstores.pinecone.Pinecone",
+    "langchain_pinecone.Pinecone"
+  ],
+  [
+    "langchain_community.vectorstores.Pinecone",
+    "langchain_pinecone.Pinecone"
+  ]
+]
@@ -1,14 +0,0 @@
-[
-  [
-    "langchain.chat_models.ChatOpenAI",
-    "langchain_openai.ChatOpenAI"
-  ],
-  [
-    "langchain.chat_models.ChatOpenAI",
-    "langchain_openai.ChatOpenAI"
-  ],
-  [
-    "langchain.chat_models.ChatAnthropic",
-    "langchain_anthropic.ChatAnthropic"
-  ]
-]
@@ -26,7 +26,7 @@ HERE = os.path.dirname(__file__)
 
 
 def _load_migrations_by_file(path: str):
-    migrations_path = os.path.join(HERE, path)
+    migrations_path = os.path.join(HERE, "migrations", path)
     with open(migrations_path, "r", encoding="utf-8") as f:
         data = json.load(f)
     return data
@@ -43,21 +43,30 @@ def _deduplicate_in_order(
     return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]
 
 
-def _load_migrations():
-    """Load the migrations from the JSON file."""
-    # Later earlier ones have higher precedence.
-    paths = [
-        "migrations_v0.2_partner.json",
-        "migrations_v0.2.json",
-    ]
+PARTNERS = [
+    "anthropic.json",
+    "ibm.json",
+    "openai.json",
+    "pinecone.json",
+    "fireworks.json",
+]
 
+
+def _load_migrations_from_fixtures() -> List[Tuple[str, str]]:
+    """Load migrations from fixtures."""
+    paths: List[str] = PARTNERS + ["langchain.json"]
     data = []
     for path in paths:
         data.extend(_load_migrations_by_file(path))
-
     data = _deduplicate_in_order(data, key=lambda x: x[0])
-
+    return data
+
+
+def _load_migrations():
+    """Load the migrations from the JSON file."""
+    # Later earlier ones have higher precedence.
     imports: Dict[str, Tuple[str, str]] = {}
+    data = _load_migrations_from_fixtures()
 
     for old_path, new_path in data:
         # Parse the old parse which is of the format 'langchain.chat_models.ChatOpenAI'
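Because _deduplicate_in_order keeps only the first occurrence of each old path, listing the partner fixtures ahead of langchain.json means a partner mapping shadows any duplicate entry from the langchain fixture. A small illustrative example (made-up paths, and a simplified restatement of the helper whose return line appears in the hunk above):

# Illustrative: duplicate old paths resolve in favour of the earlier fixture.
def _deduplicate_in_order(seq, key=lambda x: x):
    seen = set()
    seen_add = seen.add
    return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]


data = [
    ("pkg.Foo", "langchain_partner.Foo"),    # from a partner fixture, listed first
    ("pkg.Foo", "langchain_community.Foo"),  # from langchain.json, dropped
]
print(_deduplicate_in_order(data, key=lambda x: x[0]))
# [('pkg.Foo', 'langchain_partner.Foo')]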
@@ -2,6 +2,7 @@
 import importlib
 from typing import List, Tuple
 
+from langchain_core.documents import BaseDocumentCompressor, BaseDocumentTransformer
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.retrievers import BaseRetriever
@@ -11,6 +12,7 @@ from langchain_cli.namespaces.migrate.generate.utils import (
     COMMUNITY_PKG,
     find_subclasses_in_module,
     list_classes_by_package,
+    list_init_imports_by_package,
 )
 
 # PUBLIC API
@@ -29,13 +31,24 @@ def get_migrations_for_partner_package(pkg_name: str) -> List[Tuple[str, str]]:
     """
     package = importlib.import_module(pkg_name)
     classes_ = find_subclasses_in_module(
-        package, [BaseLanguageModel, Embeddings, BaseRetriever, VectorStore]
+        package,
+        [
+            BaseLanguageModel,
+            Embeddings,
+            BaseRetriever,
+            VectorStore,
+            BaseDocumentTransformer,
+            BaseDocumentCompressor,
+        ],
     )
     community_classes = list_classes_by_package(str(COMMUNITY_PKG))
+    imports_for_pkg = list_init_imports_by_package(str(COMMUNITY_PKG))
+
+    old_paths = community_classes + imports_for_pkg
 
     migrations = [
-        (f"{community_module}.{community_class}", f"{pkg_name}.{community_class}")
-        for community_module, community_class in community_classes
-        if community_class in classes_
+        (f"{module}.{item}", f"{pkg_name}.{item}")
+        for module, item in old_paths
+        if item in classes_
     ]
     return migrations
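A rough usage sketch for the generator above, assuming the module path langchain_cli.namespaces.migrate.generate.partner (inferred from the utils import in the previous hunk) and that langchain_community plus the partner package are importable in the current environment:

# Sketch only: the module path and installed packages are assumptions.
from langchain_cli.namespaces.migrate.generate.partner import (
    get_migrations_for_partner_package,
)

migrations = get_migrations_for_partner_package("langchain_openai")
# Each entry pairs an old community path with its partner replacement, e.g.
# ("langchain_community.chat_models.ChatOpenAI", "langchain_openai.ChatOpenAI").
for old, new in migrations[:3]:
    print(old, "->", new)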
@@ -3,7 +3,7 @@ import inspect
 import os
 import pathlib
 from pathlib import Path
-from typing import Any, List, Tuple, Type
+from typing import Any, List, Optional, Tuple, Type
 
 HERE = Path(__file__).parent
 # Should bring us to [root]/src
@@ -15,13 +15,15 @@ PARTNER_PKGS = PKGS_ROOT / "partners"
 
 
 class ImportExtractor(ast.NodeVisitor):
-    def __init__(self, *, from_package: str) -> None:
-        """Extract all imports from the given package."""
+    def __init__(self, *, from_package: Optional[str] = None) -> None:
+        """Extract all imports from the given code, optionally filtering by package."""
         self.imports = []
         self.package = from_package
 
     def visit_ImportFrom(self, node):
-        if node.module and str(node.module).startswith(self.package):
+        if node.module and (
+            self.package is None or str(node.module).startswith(self.package)
+        ):
             for alias in node.names:
                 self.imports.append((node.module, alias.name))
         self.generic_visit(node)
@@ -72,13 +74,40 @@ def _get_all_classnames_from_file(file: str, pkg: str) -> List[Tuple[str, str]]:
         code = f.read()
     module_name = _get_current_module(file, pkg)
     class_names = _get_class_names(code)
 
     return [(module_name, class_name) for class_name in class_names]
 
+
+def identify_all_imports_in_file(
+    file: str, *, from_package: Optional[str] = None
+) -> List[Tuple[str, str]]:
+    """Let's also identify all the imports in the given file."""
+    with open(file, encoding="utf-8") as f:
+        code = f.read()
+    return find_imports_from_package(code, from_package=from_package)
+
+
+def identify_pkg_source(pkg_root: str) -> pathlib.Path:
+    """Identify the source of the package.
+
+    Args:
+        pkg_root: the root of the package. This contains source + tests, and other
+            things like pyproject.toml, lock files etc
+
+    Returns:
+        Returns the path to the source code for the package.
+    """
+    dirs = [d for d in Path(pkg_root).iterdir() if d.is_dir()]
+    matching_dirs = [d for d in dirs if d.name.startswith("langchain_")]
+    assert len(matching_dirs) == 1, "There should be only one langchain package."
+    return matching_dirs[0]
+
+
 def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]:
     """List all classes in a package."""
     module_classes = []
-    files = list(Path(pkg_root).rglob("*.py"))
+    pkg_source = identify_pkg_source(pkg_root)
+    files = list(pkg_source.rglob("*.py"))
 
     for file in files:
         rel_path = os.path.relpath(file, pkg_root)
@@ -88,11 +117,29 @@ def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]:
     return module_classes
 
 
-def find_imports_from_package(code: str, *, from_package: str) -> List[Tuple[str, str]]:
+def list_init_imports_by_package(pkg_root: str) -> List[Tuple[str, str]]:
+    """List all the things that are being imported in a package by module."""
+    imports = []
+    pkg_source = identify_pkg_source(pkg_root)
+    # Scan all the files in the package
+    files = list(Path(pkg_source).rglob("*.py"))
+
+    for file in files:
+        if not file.name == "__init__.py":
+            continue
+        import_in_file = identify_all_imports_in_file(str(file))
+        module_name = _get_current_module(file, pkg_root)
+        imports.extend([(module_name, item) for _, item in import_in_file])
+    return imports
+
+
+def find_imports_from_package(
+    code: str, *, from_package: Optional[str] = None
+) -> List[Tuple[str, str]]:
     # Parse the code into an AST
     tree = ast.parse(code)
     # Create an instance of the visitor
-    extractor = ImportExtractor(from_package="langchain_community")
+    extractor = ImportExtractor(from_package=from_package)
     # Use the visitor to update the imports list
     extractor.visit(tree)
     return extractor.imports
@@ -50,7 +50,7 @@ select = [
 ]
 
 [tool.poe.tasks]
-test = "poetry run pytest"
+test = "poetry run pytest tests"
 watch = "poetry run ptw"
 version = "poetry version --short"
 bump = ["_bump_1", "_bump_2"]
@@ -1,5 +1,6 @@
 """Script to generate migrations for the migration script."""
 import json
+import pkgutil
 
 import click
 
@@ -38,10 +39,39 @@ def partner(pkg: str, output: str) -> None:
     """Generate migration scripts specifically for LangChain modules."""
     click.echo("Migration script for LangChain generated.")
     migrations = get_migrations_for_partner_package(pkg)
-    output_name = f"partner_{pkg}.json" if output is None else output
-    with open(output_name, "w") as f:
-        f.write(json.dumps(migrations, indent=2, sort_keys=True))
-    click.secho(f"LangChain migration script saved to {output_name}")
+    # Run with python 3.9+
+    output_name = f"{pkg.removeprefix('langchain_')}.json" if output is None else output
+    if migrations:
+        with open(output_name, "w") as f:
+            f.write(json.dumps(migrations, indent=2, sort_keys=True))
+        click.secho(f"LangChain migration script saved to {output_name}")
+    else:
+        click.secho(f"No migrations found for {pkg}", fg="yellow")
+
+
+@cli.command()
+def all_installed_partner_pkgs() -> None:
+    """Generate migration scripts for all LangChain modules."""
+    # Will generate migrations for all pather packages.
+    # Define as "langchain_<partner_name>".
+    # First let's determine which packages are installed in the environment
+    # and then generate migrations for them.
+    langchain_pkgs = [
+        name
+        for _, name, _ in pkgutil.iter_modules()
+        if name.startswith("langchain_")
+        and name not in {"langchain_core", "langchain_cli", "langchain_community"}
+    ]
+    for pkg in langchain_pkgs:
+        migrations = get_migrations_for_partner_package(pkg)
+        # Run with python 3.9+
+        output_name = f"{pkg.removeprefix('langchain_')}.json"
+        if migrations:
+            with open(output_name, "w") as f:
+                f.write(json.dumps(migrations, indent=2, sort_keys=True))
+            click.secho(f"LangChain migration script saved to {output_name}")
+        else:
+            click.secho(f"No migrations found for {pkg}", fg="yellow")
+
+
 if __name__ == "__main__":
@@ -1 +0,0 @@
-[["langchain_community.embeddings.openai.OpenAIEmbeddings", "langchain_openai.embeddings.base.OpenAIEmbeddings"], ["langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings", "langchain_openai.embeddings.azure.AzureOpenAIEmbeddings"], ["langchain_community.chat_models.openai.ChatOpenAI", "langchain_openai.chat_models.base.ChatOpenAI"], ["langchain_community.chat_models.azure_openai.AzureChatOpenAI", "langchain_openai.chat_models.azure.AzureChatOpenAI"]]
@@ -1,7 +1,7 @@
-from tests.unit_tests.migrate.integration.case import Case
-from tests.unit_tests.migrate.integration.cases import imports
-from tests.unit_tests.migrate.integration.file import File
-from tests.unit_tests.migrate.integration.folder import Folder
+from tests.unit_tests.migrate.cli_runner.case import Case
+from tests.unit_tests.migrate.cli_runner.cases import imports
+from tests.unit_tests.migrate.cli_runner.file import File
+from tests.unit_tests.migrate.cli_runner.folder import Folder
 
 cases = [
     Case(
@@ -1,5 +1,5 @@
-from tests.unit_tests.migrate.integration.case import Case
-from tests.unit_tests.migrate.integration.file import File
+from tests.unit_tests.migrate.cli_runner.case import Case
+from tests.unit_tests.migrate.cli_runner.file import File
 
 cases = [
     Case(
@@ -7,7 +7,7 @@ cases = [
         source=File(
             "app.py",
             content=[
-                "from langchain.chat_models import ChatOpenAI",
+                "from langchain_community.chat_models import ChatOpenAI",
                 "",
                 "",
                 "class foo:",
@@ -28,4 +28,19 @@ def test_generate_migrations() -> None:
             "langchain_community.chat_models.azure_openai.AzureChatOpenAI",
             "langchain_openai.AzureChatOpenAI",
         ),
+        ("langchain_community.llms.AzureOpenAI", "langchain_openai.AzureOpenAI"),
+        ("langchain_community.llms.OpenAI", "langchain_openai.OpenAI"),
+        (
+            "langchain_community.embeddings.AzureOpenAIEmbeddings",
+            "langchain_openai.AzureOpenAIEmbeddings",
+        ),
+        (
+            "langchain_community.embeddings.OpenAIEmbeddings",
+            "langchain_openai.OpenAIEmbeddings",
+        ),
+        (
+            "langchain_community.chat_models.AzureChatOpenAI",
+            "langchain_openai.AzureChatOpenAI",
+        ),
+        ("langchain_community.chat_models.ChatOpenAI", "langchain_openai.ChatOpenAI"),
     ]
|
@ -19,6 +19,16 @@ class TestReplaceImportsCommand(CodemodTest):
|
|||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
"""
|
"""
|
||||||
after = """
|
after = """
|
||||||
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
|
"""
|
||||||
|
self.assertCodemod(before, after)
|
||||||
|
|
||||||
|
def test_from_community_to_partner(self) -> None:
|
||||||
|
"""Test that we can replace imports from community to partner."""
|
||||||
|
before = """
|
||||||
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
|
"""
|
||||||
|
after = """
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
"""
|
"""
|
||||||
self.assertCodemod(before, after)
|
self.assertCodemod(before, after)
|
||||||
@@ -31,10 +41,10 @@ class TestReplaceImportsCommand(CodemodTest):
 
     def test_mixed_imports(self) -> None:
         before = """
-        from langchain.chat_models import ChatOpenAI, ChatAnthropic, foo
+        from langchain_community.chat_models import ChatOpenAI, ChatAnthropic, foo
         """
         after = """
-        from langchain.chat_models import foo
+        from langchain_community.chat_models import foo
         from langchain_anthropic import ChatAnthropic
         from langchain_openai import ChatOpenAI
         """