Compare commits

...

4 Commits

Author SHA1 Message Date
Bagatur
93e59a031a fmt 2024-03-14 19:24:12 -07:00
Bagatur
8c08b66b22 Merge branch 'master' into bagatur/community_migration_script 2024-03-14 18:38:57 -07:00
Bagatur
8004e2efaf Merge branch 'master' into bagatur/community_migration_script 2024-03-14 14:55:00 -07:00
Bagatur
dedd47783b wip 2024-03-13 16:54:14 -07:00
12 changed files with 7321 additions and 20 deletions

View File

@@ -3,9 +3,7 @@ from collections import defaultdict
from html.parser import HTMLParser
from typing import Any, DefaultDict, Dict, List, Optional, cast
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.schema import (
ChatGeneration,
ChatResult,

View File

@@ -2,9 +2,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional, cast
from langchain.callbacks.manager import (
CallbackManagerForChainRun,
)
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains import LLMChain
from langchain.chains.base import Chain
from langchain.prompts.prompt import PromptTemplate

View File

@@ -12,6 +12,8 @@
Serializable, Generation, PromptValue
""" # noqa: E501
from langchain_community.output_parsers.rail_parser import GuardrailsOutputParser
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.output_parsers.combining import CombiningOutputParser
from langchain.output_parsers.datetime import DatetimeOutputParser
@@ -30,7 +32,6 @@ from langchain.output_parsers.openai_tools import (
)
from langchain.output_parsers.pandas_dataframe import PandasDataFrameOutputParser
from langchain.output_parsers.pydantic import PydanticOutputParser
from langchain.output_parsers.rail_parser import GuardrailsOutputParser
from langchain.output_parsers.regex import RegexParser
from langchain.output_parsers.regex_dict import RegexDictParser
from langchain.output_parsers.retry import RetryOutputParser, RetryWithErrorOutputParser

View File

@@ -20,6 +20,7 @@ the backbone of a retriever, but there are other types of retrievers as well.
import warnings
from typing import Any
from langchain_community.retrievers.outline import OutlineRetriever
from langchain_core._api import LangChainDeprecationWarning
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
@@ -27,7 +28,6 @@ from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.retrievers.outline import OutlineRetriever
from langchain.retrievers.parent_document_retriever import ParentDocumentRetriever
from langchain.retrievers.re_phraser import RePhraseQueryRetriever
from langchain.retrievers.self_query.base import SelfQueryRetriever

View File

@@ -7,9 +7,7 @@ from langchain_core.callbacks import (
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain.retrievers.document_compressors.base import (
BaseDocumentCompressor,
)
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
class ContextualCompressionRetriever(BaseRetriever):

View File

@@ -1,4 +1,4 @@
from langchain.retrievers.pubmed import PubMedRetriever
from langchain_community.retrievers.pubmed import PubMedRetriever
__all__ = [
"PubMedRetriever",

View File

@@ -88,11 +88,7 @@ or LangSmith's `RunEvaluator` classes.
- :func:`run_on_dataset <langchain.smith.evaluation.runner_utils.run_on_dataset>`: Function to evaluate a chain, agent, or other LangChain component over a dataset.
- :class:`RunEvalConfig <langchain.smith.evaluation.config.RunEvalConfig>`: Class representing the configuration for running evaluation. You can select evaluators by :class:`EvaluatorType <langchain.evaluation.schema.EvaluatorType>` or config, or you can pass in `custom_evaluators`
""" # noqa: E501
from langchain.smith.evaluation import (
RunEvalConfig,
arun_on_dataset,
run_on_dataset,
)
from langchain.smith.evaluation import RunEvalConfig, arun_on_dataset, run_on_dataset
__all__ = [
"arun_on_dataset",

View File

@@ -3,10 +3,9 @@ import re
from pathlib import Path
from typing import Iterator, List, Optional, Sequence, Tuple, Union
from langchain_community.storage.exceptions import InvalidKeyException
from langchain_core.stores import ByteStore
from langchain.storage.exceptions import InvalidKeyException
class LocalFileStore(ByteStore):
"""BaseStore interface that works on the local file system.

View File

@@ -4,6 +4,7 @@
These functions do not depend on any other LangChain module.
"""
from langchain_community.utils.math import cosine_similarity, cosine_similarity_top_k
from langchain_core.utils.formatting import StrictFormatter, formatter
from langchain_core.utils.input import (
get_bolded_text,
@@ -22,7 +23,6 @@ from langchain_core.utils.utils import (
)
from langchain.utils.env import get_from_dict_or_env, get_from_env
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
__all__ = [

View File

@@ -0,0 +1,79 @@
import ast
import glob
import os
from pathlib import Path

# Scan the langchain package for names that are (directly or lazily) imported
# from langchain_community, and write the mapping
# {(old_module, public_name): new location} to scripts/_old_to_new.py for the
# migration script to consume.
CUR_DIR = Path(os.path.abspath(__file__)).parent
PARENT_DIR = CUR_DIR.parent

imports_to_migrate = []
# recursive=True is required for "**" to actually match nested packages;
# without it "**" behaves like a single "*" segment.
for file_path in glob.glob(
    str(PARENT_DIR / "libs/langchain/langchain/**/*.py"), recursive=True
):
    # Turn ".../libs/langchain/langchain/foo/bar.py" into "langchain.foo.bar".
    # 16 == len("/libs/langchain/"); -3 strips the ".py" suffix.
    curr_module = file_path[len(str(PARENT_DIR)) + 16 : -3].replace("/", ".")
    if curr_module.endswith(".__init__"):
        # A package's __init__ module is imported as the package itself.
        curr_module = curr_module[: -len(".__init__")]
    # Python source files default to utf-8; don't depend on the locale.
    with open(file_path, "r", encoding="utf-8") as file:
        raw_contents = file.read()
    module = ast.parse(raw_contents, filename=file_path)

    # Case 1: direct top-level `from langchain_community... import ...`.
    for node in module.body:
        if not isinstance(node, ast.ImportFrom):
            continue
        # node.module is None for relative imports (`from . import x`);
        # guard before the substring test to avoid a TypeError.
        if not node.module or "langchain_community" not in node.module:
            continue
        names = [n.__dict__ for n in node.names]
        imports_to_migrate.append(
            {"old_module": curr_module, "new_module": node.module, "imports": names}
        )

    # Case 2: names exposed lazily via a module-level __getattr__ that
    # imports from langchain_community on first attribute access.
    if "def __getattr__" in raw_contents and "from langchain_community" in raw_contents:
        direct_imports = [
            alias.asname or alias.name
            for n in module.body
            if isinstance(n, (ast.Import, ast.ImportFrom))
            for alias in n.names
        ]
        getattr_fn = [
            n
            for n in module.body
            if isinstance(n, ast.FunctionDef) and n.name == "__getattr__"
        ][0]
        community_import = [
            n
            for n in getattr_fn.body
            if isinstance(n, ast.ImportFrom)
            and n.module
            and "langchain_community" in n.module
        ]
        if not community_import:
            continue
        community_import = community_import[0]
        community_mod = community_import.module + "." + community_import.names[0].name
        all_assign = [
            n
            for n in module.body
            # Guard on ast.Name: tuple/attribute targets have no .id.
            if isinstance(n, ast.Assign)
            and isinstance(n.targets[0], ast.Name)
            and n.targets[0].id == "__all__"
        ][0]
        all_val = {e.value for e in all_assign.value.elts}
        # Anything in __all__ that is not imported directly must be served by
        # __getattr__, i.e. it lives in the community module found above.
        indirect_imports = all_val.difference(direct_imports)
        indirect_aliases = [{"name": x, "asname": None} for x in indirect_imports]
        imports_to_migrate.append(
            {
                "old_module": curr_module,
                "new_module": community_mod,
                "imports": indirect_aliases,
            }
        )

# Flatten into {(old_module, public_name): {new_module, name, asname}}, keyed
# by the name callers actually see (the alias when one exists).
old_to_new = {
    (x["old_module"], i["asname"] or i["name"]): {
        "new_module": x["new_module"],
        "name": i["name"],
        "asname": i["asname"],
    }
    for x in imports_to_migrate
    for i in x["imports"]
}
with open(CUR_DIR / "_old_to_new.py", "w", encoding="utf-8") as f:
    f.write(f"OLD_TO_NEW={old_to_new}")

7146
scripts/_old_to_new.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,86 @@
import ast
import glob
import sys
from pathlib import Path
from typing import Union
from _old_to_new import OLD_TO_NEW
# Modules that contain at least one migratable name. OLD_TO_NEW is keyed by
# (old_module, name) tuples, so k[0] is the old module path; this set gives a
# cheap per-statement membership test before the per-name lookup in migrate().
OLD_TO_NEW_MODS_ONLY = {k[0] for k in OLD_TO_NEW}
def migrate(dir_: Union[str, Path]) -> None:
    """Rewrite legacy imports in every ``.py`` file under ``dir_``, in place.

    Each top-level ``from <old_module> import ...`` whose module appears in
    ``OLD_TO_NEW`` is split: names with a known new home are re-imported from
    their new module, names without a mapping stay on the original module.
    Only the import statements themselves are replaced; every other line of
    the file is preserved byte-for-byte.

    Args:
        dir_: Directory whose Python files should be migrated.
    """
    # recursive=True is required for "**" to match arbitrarily nested files;
    # without it "**" degrades to a single "*" path segment.
    for file_path in glob.glob(str(Path(dir_).absolute() / "**/*.py"), recursive=True):
        # Python source defaults to utf-8; don't depend on the locale.
        with open(file_path, "r", encoding="utf-8") as file:
            lines = file.readlines()
        module = ast.parse("".join(lines), filename=file_path, type_comments=True)
        # One (start_index, end_index, replacement_lines) entry per rewritten
        # import statement; indices are 0-based slice bounds into ``lines``.
        replacements = []
        for node in module.body:
            if not isinstance(node, ast.ImportFrom) or (
                node.module not in OLD_TO_NEW_MODS_ONLY
            ):
                continue
            new_import_froms = []

            def _existing_import(mod):
                # Return the already-collected ImportFrom for *mod*, if any,
                # so names bound for the same module share one statement.
                for candidate in new_import_froms:
                    if getattr(candidate, "module", None) == mod:
                        return candidate
                return None

            for imported in node.names:
                if (node.module, imported.name) in OLD_TO_NEW:
                    new_import = OLD_TO_NEW[(node.module, imported.name)]
                    asname = imported.asname or new_import["asname"]
                    # BUGFIX: the ast.alias field is ``asname``; the previous
                    # ``as_name`` keyword was ignored and dropped every alias.
                    new_alias = ast.alias(name=new_import["name"], asname=asname)
                    target = _existing_import(new_import["new_module"])
                    if target is not None:
                        target.names.append(new_alias)
                    else:
                        # Reuse the original node's location/level attributes.
                        node_params = node.__dict__.copy()
                        node_params.pop("module")
                        node_params.pop("names")
                        new_import_froms.append(
                            ast.ImportFrom(
                                module=new_import["new_module"],
                                names=[new_alias],
                                **node_params,
                            )
                        )
                else:
                    # No mapping: keep this name on its original module.
                    target = _existing_import(node.module)
                    if target is not None:
                        target.names.append(imported)
                    else:
                        node_params = node.__dict__.copy()
                        node_params["names"] = [imported]
                        new_import_froms.append(ast.ImportFrom(**node_params))
            rendered = ast.unparse(
                ast.fix_missing_locations(
                    ast.Module(body=new_import_froms, type_ignores=[])
                )
            )
            replacement_lines = [x + "\n" for x in rendered.split("\n")]
            # node.lineno is 1-based; convert to a 0-based slice start.
            replacements.append((node.lineno - 1, node.end_lineno, replacement_lines))
        if replacements:
            # Stitch the untouched spans between rewritten imports back
            # together with their replacements, in source order.
            final_lines = []
            last_end = 0
            for start, end, replacement_lines in replacements:
                final_lines += lines[last_end:start] + replacement_lines
                last_end = end
            final_lines += lines[last_end:]
        else:
            final_lines = lines
        with open(file_path, "w", encoding="utf-8") as file:
            file.write("".join(final_lines))
# CLI entry point: ``python migrate.py <directory-to-migrate>``.
if __name__ == "__main__":
    migrate(sys.argv[1])