cli[minor]: Add ipynb support, add text_splitters (#20963)

Eugene Yurtsev 2024-04-29 10:11:21 -04:00 committed by GitHub
parent 5e0b6b3e75
commit d781560722
13 changed files with 2632 additions and 6525 deletions

View File

@@ -5,21 +5,41 @@ from libcst.codemod import ContextAwareTransformer
 from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

 from langchain_cli.namespaces.migrate.codemods.replace_imports import (
-    ReplaceImportsCodemod,
+    generate_import_replacer,
 )


 class Rule(str, Enum):
-    R001 = "R001"
-    """Replace imports that have been moved."""
+    langchain_to_community = "langchain_to_community"
+    """Replace deprecated langchain imports with current ones in community."""
+    langchain_to_core = "langchain_to_core"
+    """Replace deprecated langchain imports with current ones in core."""
+    langchain_to_text_splitters = "langchain_to_text_splitters"
+    """Replace deprecated langchain imports with current ones in text splitters."""
+    community_to_core = "community_to_core"
+    """Replace deprecated community imports with current ones in core."""
+    community_to_partner = "community_to_partner"
+    """Replace deprecated community imports with current ones in partner."""


 def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
+    """Gather codemods based on the disabled rules."""
     codemods: List[Type[ContextAwareTransformer]] = []
-    if Rule.R001 not in disabled:
-        codemods.append(ReplaceImportsCodemod)
+
+    # Import rules
+    import_rules = {
+        Rule.langchain_to_community,
+        Rule.langchain_to_core,
+        Rule.community_to_core,
+        Rule.community_to_partner,
+        Rule.langchain_to_text_splitters,
+    }
+
+    # Find active import rules
+    active_import_rules = import_rules - set(disabled)
+
+    if active_import_rules:
+        codemods.append(generate_import_replacer(active_import_rules))
     # Those codemods need to be the last ones.
     codemods.extend([RemoveImportsVisitor, AddImportsVisitor])
     return codemods
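
For orientation, a minimal usage sketch of the reworked entry point. The module path is an assumption inferred from the imports in the hunk above, not stated in the diff:

# Hedged sketch, assuming this file is
# langchain_cli/namespaces/migrate/codemods/__init__.py:
from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods

# Disabling one rule still merges all remaining rules into a single
# generated import-replacer class, followed by the two cleanup visitors.
codemods = gather_codemods(disabled=[Rule.community_to_partner])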

View File

@@ -0,0 +1,110 @@
[
[
"langchain_community.callbacks.tracers.ConsoleCallbackHandler",
"langchain_core.tracers.ConsoleCallbackHandler"
],
[
"langchain_community.callbacks.tracers.FunctionCallbackHandler",
"langchain_core.tracers.stdout.FunctionCallbackHandler"
],
[
"langchain_community.callbacks.tracers.LangChainTracer",
"langchain_core.tracers.LangChainTracer"
],
[
"langchain_community.callbacks.tracers.LangChainTracerV1",
"langchain_core.tracers.langchain_v1.LangChainTracerV1"
],
[
"langchain_community.docstore.document.Document",
"langchain_core.documents.Document"
],
[
"langchain_community.document_loaders.Blob",
"langchain_core.document_loaders.Blob"
],
[
"langchain_community.document_loaders.BlobLoader",
"langchain_core.document_loaders.BlobLoader"
],
[
"langchain_community.document_loaders.base.BaseBlobParser",
"langchain_core.document_loaders.BaseBlobParser"
],
[
"langchain_community.document_loaders.base.BaseLoader",
"langchain_core.document_loaders.BaseLoader"
],
[
"langchain_community.document_loaders.blob_loaders.Blob",
"langchain_core.document_loaders.Blob"
],
[
"langchain_community.document_loaders.blob_loaders.BlobLoader",
"langchain_core.document_loaders.BlobLoader"
],
[
"langchain_community.document_loaders.blob_loaders.schema.Blob",
"langchain_core.document_loaders.Blob"
],
[
"langchain_community.document_loaders.blob_loaders.schema.BlobLoader",
"langchain_core.document_loaders.BlobLoader"
],
[
"langchain_community.tools.BaseTool",
"langchain_core.tools.BaseTool"
],
[
"langchain_community.tools.StructuredTool",
"langchain_core.tools.StructuredTool"
],
[
"langchain_community.tools.Tool",
"langchain_core.tools.Tool"
],
[
"langchain_community.tools.format_tool_to_openai_function",
"langchain_core.utils.function_calling.format_tool_to_openai_function"
],
[
"langchain_community.tools.tool",
"langchain_core.tools.tool"
],
[
"langchain_community.tools.convert_to_openai.format_tool_to_openai_function",
"langchain_core.utils.function_calling.format_tool_to_openai_function"
],
[
"langchain_community.tools.convert_to_openai.format_tool_to_openai_tool",
"langchain_core.utils.function_calling.format_tool_to_openai_tool"
],
[
"langchain_community.tools.render.format_tool_to_openai_function",
"langchain_core.utils.function_calling.format_tool_to_openai_function"
],
[
"langchain_community.tools.render.format_tool_to_openai_tool",
"langchain_core.utils.function_calling.format_tool_to_openai_tool"
],
[
"langchain_community.utils.openai_functions.FunctionDescription",
"langchain_core.utils.function_calling.FunctionDescription"
],
[
"langchain_community.utils.openai_functions.ToolDescription",
"langchain_core.utils.function_calling.ToolDescription"
],
[
"langchain_community.utils.openai_functions.convert_pydantic_to_openai_function",
"langchain_core.utils.function_calling.convert_pydantic_to_openai_function"
],
[
"langchain_community.utils.openai_functions.convert_pydantic_to_openai_tool",
"langchain_core.utils.function_calling.convert_pydantic_to_openai_tool"
],
[
"langchain_community.vectorstores.VectorStore",
"langchain_core.vectorstores.VectorStore"
]
]
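
Each pair in this fixture maps a deprecated import path to its replacement. For example, the Document entry above means the codemod rewrites:

# Before: deprecated import via langchain_community
from langchain_community.docstore.document import Document
# After: the langchain_core location recorded in the pair above
from langchain_core.documents import Document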

View File

@@ -383,314 +383,10 @@
 "langchain.agents.agent_toolkits.steam.toolkit.SteamToolkit",
 "langchain_community.agent_toolkits.SteamToolkit"
 ],
-[
-"langchain.agents.agent_toolkits.vectorstore.toolkit.BaseToolkit",
-"langchain_community.agent_toolkits.base.BaseToolkit"
-],
-[
-"langchain.agents.agent_toolkits.vectorstore.toolkit.OpenAI",
-"langchain_community.llms.OpenAI"
-],
-[
-"langchain.agents.agent_toolkits.vectorstore.toolkit.VectorStoreQATool",
-"langchain_community.tools.VectorStoreQATool"
-],
-[
-"langchain.agents.agent_toolkits.vectorstore.toolkit.VectorStoreQAWithSourcesTool",
-"langchain_community.tools.VectorStoreQAWithSourcesTool"
-],
 [
 "langchain.agents.agent_toolkits.zapier.toolkit.ZapierToolkit",
 "langchain_community.agent_toolkits.ZapierToolkit"
 ],
-[
-"langchain.agents.load_tools.ArxivAPIWrapper",
-"langchain_community.utilities.ArxivAPIWrapper"
-],
-[
-"langchain.agents.load_tools.ArxivQueryRun",
-"langchain_community.tools.ArxivQueryRun"
-],
-[
-"langchain.agents.load_tools.BaseGraphQLTool",
-"langchain_community.tools.BaseGraphQLTool"
-],
-[
-"langchain.agents.load_tools.BingSearchAPIWrapper",
-"langchain_community.utilities.BingSearchAPIWrapper"
-],
-[
-"langchain.agents.load_tools.BingSearchRun",
-"langchain_community.tools.BingSearchRun"
-],
-[
-"langchain.agents.load_tools.DallEAPIWrapper",
-"langchain_community.utilities.dalle_image_generator.DallEAPIWrapper"
-],
-[
-"langchain.agents.load_tools.DataForSeoAPISearchResults",
-"langchain_community.tools.dataforseo_api_search.tool.DataForSeoAPISearchResults"
-],
-[
-"langchain.agents.load_tools.DataForSeoAPISearchRun",
-"langchain_community.tools.dataforseo_api_search.tool.DataForSeoAPISearchRun"
-],
-[
-"langchain.agents.load_tools.DataForSeoAPIWrapper",
-"langchain_community.utilities.dataforseo_api_search.DataForSeoAPIWrapper"
-],
-[
-"langchain.agents.load_tools.DuckDuckGoSearchAPIWrapper",
-"langchain_community.utilities.DuckDuckGoSearchAPIWrapper"
-],
-[
-"langchain.agents.load_tools.DuckDuckGoSearchRun",
-"langchain_community.tools.DuckDuckGoSearchRun"
-],
-[
-"langchain.agents.load_tools.ElevenLabsText2SpeechTool",
-"langchain_community.tools.ElevenLabsText2SpeechTool"
-],
-[
-"langchain.agents.load_tools.GoldenQueryAPIWrapper",
-"langchain_community.utilities.GoldenQueryAPIWrapper"
-],
-[
-"langchain.agents.load_tools.GoldenQueryRun",
-"langchain_community.tools.golden_query.tool.GoldenQueryRun"
-],
-[
-"langchain.agents.load_tools.GoogleCloudTextToSpeechTool",
-"langchain_community.tools.GoogleCloudTextToSpeechTool"
-],
-[
-"langchain.agents.load_tools.GoogleFinanceAPIWrapper",
-"langchain_community.utilities.GoogleFinanceAPIWrapper"
-],
-[
-"langchain.agents.load_tools.GoogleFinanceQueryRun",
-"langchain_community.tools.google_finance.tool.GoogleFinanceQueryRun"
-],
-[
-"langchain.agents.load_tools.GoogleJobsAPIWrapper",
-"langchain_community.utilities.GoogleJobsAPIWrapper"
-],
-[
-"langchain.agents.load_tools.GoogleJobsQueryRun",
-"langchain_community.tools.google_jobs.tool.GoogleJobsQueryRun"
-],
-[
-"langchain.agents.load_tools.GoogleLensAPIWrapper",
-"langchain_community.utilities.GoogleLensAPIWrapper"
-],
-[
-"langchain.agents.load_tools.GoogleLensQueryRun",
-"langchain_community.tools.google_lens.tool.GoogleLensQueryRun"
-],
-[
-"langchain.agents.load_tools.GoogleScholarAPIWrapper",
-"langchain_community.utilities.GoogleScholarAPIWrapper"
-],
-[
-"langchain.agents.load_tools.GoogleScholarQueryRun",
-"langchain_community.tools.google_scholar.tool.GoogleScholarQueryRun"
-],
-[
-"langchain.agents.load_tools.GoogleSearchAPIWrapper",
-"langchain_community.utilities.GoogleSearchAPIWrapper"
-],
-[
-"langchain.agents.load_tools.GoogleSearchResults",
-"langchain_community.tools.GoogleSearchResults"
-],
-[
-"langchain.agents.load_tools.GoogleSearchRun",
-"langchain_community.tools.GoogleSearchRun"
-],
-[
-"langchain.agents.load_tools.GoogleSerperAPIWrapper",
-"langchain_community.utilities.GoogleSerperAPIWrapper"
-],
-[
-"langchain.agents.load_tools.GoogleSerperResults",
-"langchain_community.tools.GoogleSerperResults"
-],
-[
-"langchain.agents.load_tools.GoogleSerperRun",
-"langchain_community.tools.GoogleSerperRun"
-],
-[
-"langchain.agents.load_tools.GoogleTrendsAPIWrapper",
-"langchain_community.utilities.GoogleTrendsAPIWrapper"
-],
-[
-"langchain.agents.load_tools.GoogleTrendsQueryRun",
-"langchain_community.tools.google_trends.tool.GoogleTrendsQueryRun"
-],
-[
-"langchain.agents.load_tools.GraphQLAPIWrapper",
-"langchain_community.utilities.GraphQLAPIWrapper"
-],
-[
-"langchain.agents.load_tools.HumanInputRun",
-"langchain_community.tools.HumanInputRun"
-],
-[
-"langchain.agents.load_tools.LambdaWrapper",
-"langchain_community.utilities.LambdaWrapper"
-],
-[
-"langchain.agents.load_tools.Memorize",
-"langchain_community.tools.memorize.tool.Memorize"
-],
-[
-"langchain.agents.load_tools.MerriamWebsterAPIWrapper",
-"langchain_community.utilities.MerriamWebsterAPIWrapper"
-],
-[
-"langchain.agents.load_tools.MerriamWebsterQueryRun",
-"langchain_community.tools.MerriamWebsterQueryRun"
-],
-[
-"langchain.agents.load_tools.MetaphorSearchAPIWrapper",
-"langchain_community.utilities.MetaphorSearchAPIWrapper"
-],
-[
-"langchain.agents.load_tools.MetaphorSearchResults",
-"langchain_community.tools.MetaphorSearchResults"
-],
-[
-"langchain.agents.load_tools.OpenWeatherMapAPIWrapper",
-"langchain_community.utilities.OpenWeatherMapAPIWrapper"
-],
-[
-"langchain.agents.load_tools.OpenWeatherMapQueryRun",
-"langchain_community.tools.OpenWeatherMapQueryRun"
-],
-[
-"langchain.agents.load_tools.PubMedAPIWrapper",
-"langchain_community.utilities.PubMedAPIWrapper"
-],
-[
-"langchain.agents.load_tools.PubmedQueryRun",
-"langchain_community.tools.PubmedQueryRun"
-],
-[
-"langchain.agents.load_tools.RedditSearchAPIWrapper",
-"langchain_community.utilities.reddit_search.RedditSearchAPIWrapper"
-],
-[
-"langchain.agents.load_tools.RedditSearchRun",
-"langchain_community.tools.RedditSearchRun"
-],
-[
-"langchain.agents.load_tools.RequestsDeleteTool",
-"langchain_community.tools.RequestsDeleteTool"
-],
-[
-"langchain.agents.load_tools.RequestsGetTool",
-"langchain_community.tools.RequestsGetTool"
-],
-[
-"langchain.agents.load_tools.RequestsPatchTool",
-"langchain_community.tools.RequestsPatchTool"
-],
-[
-"langchain.agents.load_tools.RequestsPostTool",
-"langchain_community.tools.RequestsPostTool"
-],
-[
-"langchain.agents.load_tools.RequestsPutTool",
-"langchain_community.tools.RequestsPutTool"
-],
-[
-"langchain.agents.load_tools.SceneXplainTool",
-"langchain_community.tools.SceneXplainTool"
-],
-[
-"langchain.agents.load_tools.SearchAPIResults",
-"langchain_community.tools.SearchAPIResults"
-],
-[
-"langchain.agents.load_tools.SearchAPIRun",
-"langchain_community.tools.SearchAPIRun"
-],
-[
-"langchain.agents.load_tools.SearchApiAPIWrapper",
-"langchain_community.utilities.SearchApiAPIWrapper"
-],
-[
-"langchain.agents.load_tools.SearxSearchResults",
-"langchain_community.tools.SearxSearchResults"
-],
-[
-"langchain.agents.load_tools.SearxSearchRun",
-"langchain_community.tools.SearxSearchRun"
-],
-[
-"langchain.agents.load_tools.SearxSearchWrapper",
-"langchain_community.utilities.SearxSearchWrapper"
-],
-[
-"langchain.agents.load_tools.SerpAPIWrapper",
-"langchain_community.utilities.SerpAPIWrapper"
-],
-[
-"langchain.agents.load_tools.ShellTool",
-"langchain_community.tools.ShellTool"
-],
-[
-"langchain.agents.load_tools.SleepTool",
-"langchain_community.tools.SleepTool"
-],
-[
-"langchain.agents.load_tools.StackExchangeAPIWrapper",
-"langchain_community.utilities.StackExchangeAPIWrapper"
-],
-[
-"langchain.agents.load_tools.StackExchangeTool",
-"langchain_community.tools.StackExchangeTool"
-],
-[
-"langchain.agents.load_tools.TextRequestsWrapper",
-"langchain_community.utilities.TextRequestsWrapper"
-],
-[
-"langchain.agents.load_tools.TwilioAPIWrapper",
-"langchain_community.utilities.TwilioAPIWrapper"
-],
-[
-"langchain.agents.load_tools.WikipediaAPIWrapper",
-"langchain_community.utilities.WikipediaAPIWrapper"
-],
-[
-"langchain.agents.load_tools.WikipediaQueryRun",
-"langchain_community.tools.WikipediaQueryRun"
-],
-[
-"langchain.agents.load_tools.WolframAlphaAPIWrapper",
-"langchain_community.utilities.WolframAlphaAPIWrapper"
-],
-[
-"langchain.agents.load_tools.WolframAlphaQueryRun",
-"langchain_community.tools.WolframAlphaQueryRun"
-],
-[
-"langchain.agents.react.base.Docstore",
-"langchain_community.docstore.base.Docstore"
-],
-[
-"langchain.agents.self_ask_with_search.base.GoogleSerperAPIWrapper",
-"langchain_community.utilities.GoogleSerperAPIWrapper"
-],
-[
-"langchain.agents.self_ask_with_search.base.SearchApiAPIWrapper",
-"langchain_community.utilities.SearchApiAPIWrapper"
-],
-[
-"langchain.agents.self_ask_with_search.base.SerpAPIWrapper",
-"langchain_community.utilities.SerpAPIWrapper"
-],
 [
 "langchain.cache.InMemoryCache",
 "langchain_community.cache.InMemoryCache"
@@ -955,14 +651,6 @@
 "langchain.callbacks.sagemaker_callback.SageMakerCallbackHandler",
 "langchain_community.callbacks.SageMakerCallbackHandler"
 ],
-[
-"langchain.callbacks.streamlit.LLMThoughtLabeler",
-"langchain_community.callbacks.LLMThoughtLabeler"
-],
-[
-"langchain.callbacks.streamlit._InternalStreamlitCallbackHandler",
-"langchain_community.callbacks.streamlit.streamlit_callback_handler.StreamlitCallbackHandler"
-],
 [
 "langchain.callbacks.streamlit.mutable_expander.ChildType",
 "langchain_community.callbacks.streamlit.mutable_expander.ChildType"
@@ -1063,126 +751,6 @@
 "langchain.callbacks.whylabs_callback.WhyLabsCallbackHandler",
 "langchain_community.callbacks.WhyLabsCallbackHandler"
 ],
-[
-"langchain.chains.api.base.TextRequestsWrapper",
-"langchain_community.utilities.TextRequestsWrapper"
-],
-[
-"langchain.chains.api.openapi.chain.APIOperation",
-"langchain_community.tools.APIOperation"
-],
-[
-"langchain.chains.api.openapi.chain.Requests",
-"langchain_community.utilities.Requests"
-],
-[
-"langchain.chains.ernie_functions.base.JsonOutputFunctionsParser",
-"langchain_community.output_parsers.ernie_functions.JsonOutputFunctionsParser"
-],
-[
-"langchain.chains.ernie_functions.base.PydanticAttrOutputFunctionsParser",
-"langchain_community.output_parsers.ernie_functions.PydanticAttrOutputFunctionsParser"
-],
-[
-"langchain.chains.ernie_functions.base.PydanticOutputFunctionsParser",
-"langchain_community.output_parsers.ernie_functions.PydanticOutputFunctionsParser"
-],
-[
-"langchain.chains.ernie_functions.base.convert_pydantic_to_ernie_function",
-"langchain_community.utils.ernie_functions.convert_pydantic_to_ernie_function"
-],
-[
-"langchain.chains.flare.base.OpenAI",
-"langchain_community.llms.OpenAI"
-],
-[
-"langchain.chains.graph_qa.arangodb.ArangoGraph",
-"langchain_community.graphs.ArangoGraph"
-],
-[
-"langchain.chains.graph_qa.base.NetworkxEntityGraph",
-"langchain_community.graphs.NetworkxEntityGraph"
-],
-[
-"langchain.chains.graph_qa.base.get_entities",
-"langchain_community.graphs.networkx_graph.get_entities"
-],
-[
-"langchain.chains.graph_qa.cypher.GraphStore",
-"langchain_community.graphs.graph_store.GraphStore"
-],
-[
-"langchain.chains.graph_qa.falkordb.FalkorDBGraph",
-"langchain_community.graphs.FalkorDBGraph"
-],
-[
-"langchain.chains.graph_qa.gremlin.GremlinGraph",
-"langchain_community.graphs.GremlinGraph"
-],
-[
-"langchain.chains.graph_qa.hugegraph.HugeGraph",
-"langchain_community.graphs.HugeGraph"
-],
-[
-"langchain.chains.graph_qa.kuzu.KuzuGraph",
-"langchain_community.graphs.KuzuGraph"
-],
-[
-"langchain.chains.graph_qa.nebulagraph.NebulaGraph",
-"langchain_community.graphs.NebulaGraph"
-],
-[
-"langchain.chains.graph_qa.neptune_cypher.NeptuneGraph",
-"langchain_community.graphs.NeptuneGraph"
-],
-[
-"langchain.chains.graph_qa.neptune_sparql.NeptuneRdfGraph",
-"langchain_community.graphs.NeptuneRdfGraph"
-],
-[
-"langchain.chains.graph_qa.ontotext_graphdb.OntotextGraphDBGraph",
-"langchain_community.graphs.OntotextGraphDBGraph"
-],
-[
-"langchain.chains.graph_qa.sparql.RdfGraph",
-"langchain_community.graphs.RdfGraph"
-],
-[
-"langchain.chains.llm_requests.TextRequestsWrapper",
-"langchain_community.utilities.TextRequestsWrapper"
-],
-[
-"langchain.chains.loading.load_llm",
-"langchain_community.llms.loading.load_llm"
-],
-[
-"langchain.chains.loading.load_llm_from_config",
-"langchain_community.llms.loading.load_llm_from_config"
-],
-[
-"langchain.chains.natbot.base.OpenAI",
-"langchain_community.llms.OpenAI"
-],
-[
-"langchain.chains.openai_functions.openapi.APIOperation",
-"langchain_community.tools.APIOperation"
-],
-[
-"langchain.chains.openai_functions.openapi.ChatOpenAI",
-"langchain_community.chat_models.ChatOpenAI"
-],
-[
-"langchain.chains.openai_functions.openapi.OpenAPISpec",
-"langchain_community.tools.OpenAPISpec"
-],
-[
-"langchain.chains.router.multi_retrieval_qa.ChatOpenAI",
-"langchain_community.chat_models.ChatOpenAI"
-],
-[
-"langchain.chains.sql_database.query.SQLDatabase",
-"langchain_community.utilities.SQLDatabase"
-],
 [
 "langchain.chat_loaders.base.BaseChatLoader",
 "langchain_community.chat_loaders.BaseChatLoader"
@@ -3695,34 +3263,6 @@
 "langchain.embeddings.xinference.XinferenceEmbeddings",
 "langchain_community.embeddings.XinferenceEmbeddings"
 ],
-[
-"langchain.evaluation.comparison.eval_chain.AzureChatOpenAI",
-"langchain_community.chat_models.AzureChatOpenAI"
-],
-[
-"langchain.evaluation.comparison.eval_chain.ChatOpenAI",
-"langchain_community.chat_models.ChatOpenAI"
-],
-[
-"langchain.evaluation.embedding_distance.base.OpenAIEmbeddings",
-"langchain_community.embeddings.OpenAIEmbeddings"
-],
-[
-"langchain.evaluation.embedding_distance.base.cosine_similarity",
-"langchain_community.utils.math.cosine_similarity"
-],
-[
-"langchain.evaluation.loading.ChatOpenAI",
-"langchain_community.chat_models.ChatOpenAI"
-],
-[
-"langchain.evaluation.scoring.eval_chain.AzureChatOpenAI",
-"langchain_community.chat_models.AzureChatOpenAI"
-],
-[
-"langchain.evaluation.scoring.eval_chain.ChatOpenAI",
-"langchain_community.chat_models.ChatOpenAI"
-],
 [
 "langchain.graphs.MemgraphGraph",
 "langchain_community.graphs.MemgraphGraph"
@@ -3835,26 +3375,6 @@
 "langchain.graphs.rdf_graph.RdfGraph",
 "langchain_community.graphs.RdfGraph"
 ],
-[
-"langchain.indexes.graph.NetworkxEntityGraph",
-"langchain_community.graphs.NetworkxEntityGraph"
-],
-[
-"langchain.indexes.graph.parse_triples",
-"langchain_community.graphs.networkx_graph.parse_triples"
-],
-[
-"langchain.indexes.vectorstore.Chroma",
-"langchain_community.vectorstores.Chroma"
-],
-[
-"langchain.indexes.vectorstore.OpenAI",
-"langchain_community.llms.OpenAI"
-],
-[
-"langchain.indexes.vectorstore.OpenAIEmbeddings",
-"langchain_community.embeddings.OpenAIEmbeddings"
-],
 [
 "langchain.llms.AI21",
 "langchain_community.llms.AI21"
@@ -4655,10 +4175,6 @@
 "langchain.memory.UpstashRedisChatMessageHistory",
 "langchain_community.chat_message_histories.UpstashRedisChatMessageHistory"
 ],
-[
-"langchain.memory.chat_memory.ChatMessageHistory",
-"langchain_community.chat_message_histories.ChatMessageHistory"
-],
 [
 "langchain.memory.chat_message_histories.AstraDBChatMessageHistory",
 "langchain_community.chat_message_histories.AstraDBChatMessageHistory"
@@ -4827,30 +4343,6 @@
 "langchain.memory.chat_message_histories.zep.ZepChatMessageHistory",
 "langchain_community.chat_message_histories.ZepChatMessageHistory"
 ],
-[
-"langchain.memory.entity.get_client",
-"langchain_community.utilities.redis.get_client"
-],
-[
-"langchain.memory.kg.KnowledgeTriple",
-"langchain_community.graphs.networkx_graph.KnowledgeTriple"
-],
-[
-"langchain.memory.kg.NetworkxEntityGraph",
-"langchain_community.graphs.NetworkxEntityGraph"
-],
-[
-"langchain.memory.kg.get_entities",
-"langchain_community.graphs.networkx_graph.get_entities"
-],
-[
-"langchain.memory.kg.parse_triples",
-"langchain_community.graphs.networkx_graph.parse_triples"
-],
-[
-"langchain.memory.zep_memory.ZepChatMessageHistory",
-"langchain_community.chat_message_histories.ZepChatMessageHistory"
-],
 [
 "langchain.output_parsers.GuardrailsOutputParser",
 "langchain_community.output_parsers.rail_parser.GuardrailsOutputParser"
@@ -5103,18 +4595,6 @@
 "langchain.retrievers.docarray.DocArrayRetriever",
 "langchain_community.retrievers.DocArrayRetriever"
 ],
-[
-"langchain.retrievers.document_compressors.embeddings_filter._get_embeddings_from_stateful_docs",
-"langchain_community.document_transformers.embeddings_redundant_filter._get_embeddings_from_stateful_docs"
-],
-[
-"langchain.retrievers.document_compressors.embeddings_filter.cosine_similarity",
-"langchain_community.utils.math.cosine_similarity"
-],
-[
-"langchain.retrievers.document_compressors.embeddings_filter.get_stateful_documents",
-"langchain_community.document_transformers.get_stateful_documents"
-],
 [
 "langchain.retrievers.elastic_search_bm25.ElasticSearchBM25Retriever",
 "langchain_community.retrievers.ElasticSearchBM25Retriever"
@@ -5243,110 +4723,6 @@
 "langchain.retrievers.remote_retriever.RemoteLangChainRetriever",
 "langchain_community.retrievers.RemoteLangChainRetriever"
 ],
-[
-"langchain.retrievers.self_query.base.AstraDB",
-"langchain_community.vectorstores.AstraDB"
-],
-[
-"langchain.retrievers.self_query.base.Chroma",
-"langchain_community.vectorstores.Chroma"
-],
-[
-"langchain.retrievers.self_query.base.DashVector",
-"langchain_community.vectorstores.DashVector"
-],
-[
-"langchain.retrievers.self_query.base.DeepLake",
-"langchain_community.vectorstores.DeepLake"
-],
-[
-"langchain.retrievers.self_query.base.Dingo",
-"langchain_community.vectorstores.Dingo"
-],
-[
-"langchain.retrievers.self_query.base.ElasticsearchStore",
-"langchain_community.vectorstores.ElasticsearchStore"
-],
-[
-"langchain.retrievers.self_query.base.Milvus",
-"langchain_community.vectorstores.Milvus"
-],
-[
-"langchain.retrievers.self_query.base.MongoDBAtlasVectorSearch",
-"langchain_community.vectorstores.MongoDBAtlasVectorSearch"
-],
-[
-"langchain.retrievers.self_query.base.MyScale",
-"langchain_community.vectorstores.MyScale"
-],
-[
-"langchain.retrievers.self_query.base.OpenSearchVectorSearch",
-"langchain_community.vectorstores.OpenSearchVectorSearch"
-],
-[
-"langchain.retrievers.self_query.base.PGVector",
-"langchain_community.vectorstores.PGVector"
-],
-[
-"langchain.retrievers.self_query.base.Pinecone",
-"langchain_community.vectorstores.Pinecone"
-],
-[
-"langchain.retrievers.self_query.base.Qdrant",
-"langchain_community.vectorstores.Qdrant"
-],
-[
-"langchain.retrievers.self_query.base.Redis",
-"langchain_community.vectorstores.Redis"
-],
-[
-"langchain.retrievers.self_query.base.SupabaseVectorStore",
-"langchain_community.vectorstores.SupabaseVectorStore"
-],
-[
-"langchain.retrievers.self_query.base.TimescaleVector",
-"langchain_community.vectorstores.TimescaleVector"
-],
-[
-"langchain.retrievers.self_query.base.Vectara",
-"langchain_community.vectorstores.Vectara"
-],
-[
-"langchain.retrievers.self_query.base.Weaviate",
-"langchain_community.vectorstores.Weaviate"
-],
-[
-"langchain.retrievers.self_query.redis.Redis",
-"langchain_community.vectorstores.Redis"
-],
-[
-"langchain.retrievers.self_query.redis.RedisFilterExpression",
-"langchain_community.vectorstores.redis.filters.RedisFilterExpression"
-],
-[
-"langchain.retrievers.self_query.redis.RedisFilterField",
-"langchain_community.vectorstores.redis.filters.RedisFilterField"
-],
-[
-"langchain.retrievers.self_query.redis.RedisFilterOperator",
-"langchain_community.vectorstores.redis.filters.RedisFilterOperator"
-],
-[
-"langchain.retrievers.self_query.redis.RedisModel",
-"langchain_community.vectorstores.redis.schema.RedisModel"
-],
-[
-"langchain.retrievers.self_query.redis.RedisNum",
-"langchain_community.vectorstores.redis.filters.RedisNum"
-],
-[
-"langchain.retrievers.self_query.redis.RedisTag",
-"langchain_community.vectorstores.redis.filters.RedisTag"
-],
-[
-"langchain.retrievers.self_query.redis.RedisText",
-"langchain_community.vectorstores.redis.filters.RedisText"
-],
 [
 "langchain.retrievers.svm.SVMRetriever",
 "langchain_community.retrievers.SVMRetriever"
@@ -5371,22 +4747,6 @@
 "langchain.retrievers.weaviate_hybrid_search.WeaviateHybridSearchRetriever",
 "langchain_community.retrievers.WeaviateHybridSearchRetriever"
 ],
-[
-"langchain.retrievers.web_research.AsyncHtmlLoader",
-"langchain_community.document_loaders.AsyncHtmlLoader"
-],
-[
-"langchain.retrievers.web_research.GoogleSearchAPIWrapper",
-"langchain_community.utilities.GoogleSearchAPIWrapper"
-],
-[
-"langchain.retrievers.web_research.Html2TextTransformer",
-"langchain_community.document_transformers.Html2TextTransformer"
-],
-[
-"langchain.retrievers.web_research.LlamaCpp",
-"langchain_community.llms.LlamaCpp"
-],
 [
 "langchain.retrievers.wikipedia.WikipediaRetriever",
 "langchain_community.retrievers.WikipediaRetriever"
@@ -5439,10 +4799,6 @@
 "langchain.storage.exceptions.InvalidKeyException",
 "langchain_community.storage.exceptions.InvalidKeyException"
 ],
-[
-"langchain.storage.file_system.InvalidKeyException",
-"langchain_community.storage.exceptions.InvalidKeyException"
-],
 [
 "langchain.storage.redis.RedisStore",
 "langchain_community.storage.RedisStore"

View File

@@ -0,0 +1,82 @@
[
[
"langchain.text_splitter.TokenTextSplitter",
"langchain_text_splitters.TokenTextSplitter"
],
[
"langchain.text_splitter.TextSplitter",
"langchain_text_splitters.TextSplitter"
],
[
"langchain.text_splitter.Tokenizer",
"langchain_text_splitters.Tokenizer"
],
[
"langchain.text_splitter.Language",
"langchain_text_splitters.Language"
],
[
"langchain.text_splitter.RecursiveCharacterTextSplitter",
"langchain_text_splitters.RecursiveCharacterTextSplitter"
],
[
"langchain.text_splitter.RecursiveJsonSplitter",
"langchain_text_splitters.RecursiveJsonSplitter"
],
[
"langchain.text_splitter.LatexTextSplitter",
"langchain_text_splitters.LatexTextSplitter"
],
[
"langchain.text_splitter.PythonCodeTextSplitter",
"langchain_text_splitters.PythonCodeTextSplitter"
],
[
"langchain.text_splitter.KonlpyTextSplitter",
"langchain_text_splitters.KonlpyTextSplitter"
],
[
"langchain.text_splitter.SpacyTextSplitter",
"langchain_text_splitters.SpacyTextSplitter"
],
[
"langchain.text_splitter.NLTKTextSplitter",
"langchain_text_splitters.NLTKTextSplitter"
],
[
"langchain.text_splitter.split_text_on_tokens",
"langchain_text_splitters.split_text_on_tokens"
],
[
"langchain.text_splitter.SentenceTransformersTokenTextSplitter",
"langchain_text_splitters.SentenceTransformersTokenTextSplitter"
],
[
"langchain.text_splitter.ElementType",
"langchain_text_splitters.ElementType"
],
[
"langchain.text_splitter.HeaderType",
"langchain_text_splitters.HeaderType"
],
[
"langchain.text_splitter.LineType",
"langchain_text_splitters.LineType"
],
[
"langchain.text_splitter.HTMLHeaderTextSplitter",
"langchain_text_splitters.HTMLHeaderTextSplitter"
],
[
"langchain.text_splitter.MarkdownHeaderTextSplitter",
"langchain_text_splitters.MarkdownHeaderTextSplitter"
],
[
"langchain.text_splitter.MarkdownTextSplitter",
"langchain_text_splitters.MarkdownTextSplitter"
],
[
"langchain.text_splitter.CharacterTextSplitter",
"langchain_text_splitters.CharacterTextSplitter"
]
]
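
Applied to user code, the first mapping above turns, for example:

# Before: text splitters re-exported through the langchain meta-package
from langchain.text_splitter import TokenTextSplitter
# After: the standalone langchain_text_splitters package
from langchain_text_splitters import TokenTextSplitter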

View File

@@ -15,11 +15,11 @@ from __future__ import annotations

 import json
 import os
 from dataclasses import dataclass
-from typing import Callable, Dict, Iterable, List, Sequence, Tuple, TypeVar
+from typing import Callable, Dict, Iterable, List, Sequence, Tuple, Type, TypeVar

 import libcst as cst
 import libcst.matchers as m
-from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
+from libcst.codemod import VisitorBasedCodemodCommand
 from libcst.codemod.visitors import AddImportsVisitor

 HERE = os.path.dirname(__file__)
@@ -43,18 +43,8 @@ def _deduplicate_in_order(
     return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]


-PARTNERS = [
-    "anthropic.json",
-    "ibm.json",
-    "openai.json",
-    "pinecone.json",
-    "fireworks.json",
-]
-
-
-def _load_migrations_from_fixtures() -> List[Tuple[str, str]]:
+def _load_migrations_from_fixtures(paths: List[str]) -> List[Tuple[str, str]]:
     """Load migrations from fixtures."""
-    paths: List[str] = PARTNERS + ["langchain_to_langchain_community.json"]
     data = []
     for path in paths:
         data.extend(_load_migrations_by_file(path))
@@ -62,11 +52,11 @@ def _load_migrations_from_fixtures() -> List[Tuple[str, str]]:
     return data


-def _load_migrations():
+def _load_migrations(paths: List[str]):
     """Load the migrations from the JSON file."""

     # Later earlier ones have higher precedence.
     imports: Dict[str, Tuple[str, str]] = {}

-    data = _load_migrations_from_fixtures()
+    data = _load_migrations_from_fixtures(paths)

     for old_path, new_path in data:
         # Parse the old parse which is of the format 'langchain.chat_models.ChatOpenAI'
@@ -88,9 +78,6 @@ def _load_migrations():
     return imports


-IMPORTS = _load_migrations()
-
-
 def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name:
     """Converts a list of module parts to a `Name` or `Attribute` node."""
     if len(module_parts) == 1:
@@ -139,76 +126,67 @@ class ImportInfo:
     to_import_str: tuple[str, str]


-IMPORT_INFOS = [
-    ImportInfo(
-        import_from=get_import_from_from_str(import_str),
-        import_str=import_str,
-        to_import_str=to_import_str,
-    )
-    for import_str, to_import_str in IMPORTS.items()
-]
-IMPORT_MATCH = m.OneOf(*[info.import_from for info in IMPORT_INFOS])
+RULE_TO_PATHS = {
+    "langchain_to_community": ["langchain_to_community.json"],
+    "langchain_to_core": ["langchain_to_core.json"],
+    "community_to_core": ["community_to_core.json"],
+    "langchain_to_text_splitters": ["langchain_to_text_splitters.json"],
+    "community_to_partner": [
+        "anthropic.json",
+        "fireworks.json",
+        "ibm.json",
+        "openai.json",
+        "pinecone.json",
+    ],
+}


-class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
-    @m.leave(IMPORT_MATCH)
-    def leave_replace_import(
-        self, _: cst.ImportFrom, updated_node: cst.ImportFrom
-    ) -> cst.ImportFrom:
-        for import_info in IMPORT_INFOS:
-            if m.matches(updated_node, import_info.import_from):
-                aliases: Sequence[cst.ImportAlias] = updated_node.names  # type: ignore
-                # If multiple objects are imported in a single import statement,
-                # we need to remove only the one we're replacing.
-                AddImportsVisitor.add_needed_import(
-                    self.context, *import_info.to_import_str
-                )
-                if len(updated_node.names) > 1:  # type: ignore
-                    names = [
-                        alias
-                        for alias in aliases
-                        if alias.name.value != import_info.to_import_str[-1]
-                    ]
-                    names[-1] = names[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
-                    updated_node = updated_node.with_changes(names=names)
-                else:
-                    return cst.RemoveFromParent()  # type: ignore[return-value]
-        return updated_node
+def generate_import_replacer(rules: List[str]) -> Type[VisitorBasedCodemodCommand]:
+    """Generate a codemod to replace imports."""
+    paths = []
+    for rule in rules:
+        if rule not in RULE_TO_PATHS:
+            raise ValueError(f"Unknown rule: {rule}. Use one of {RULE_TO_PATHS.keys()}")

+        paths.extend(RULE_TO_PATHS[rule])

-if __name__ == "__main__":
-    import textwrap
+    imports = _load_migrations(paths)

-    from rich.console import Console
+    import_infos = [
+        ImportInfo(
+            import_from=get_import_from_from_str(import_str),
+            import_str=import_str,
+            to_import_str=to_import_str,
+        )
+        for import_str, to_import_str in imports.items()
+    ]
+    import_match = m.OneOf(*[info.import_from for info in import_infos])

-    console = Console()
+    class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
+        @m.leave(import_match)
+        def leave_replace_import(
+            self, _: cst.ImportFrom, updated_node: cst.ImportFrom
+        ) -> cst.ImportFrom:
+            for import_info in import_infos:
+                if m.matches(updated_node, import_info.import_from):
+                    aliases: Sequence[cst.ImportAlias] = updated_node.names  # type: ignore
+                    # If multiple objects are imported in a single import statement,
+                    # we need to remove only the one we're replacing.
+                    AddImportsVisitor.add_needed_import(
+                        self.context, *import_info.to_import_str
+                    )
+                    if len(updated_node.names) > 1:  # type: ignore
+                        names = [
+                            alias
+                            for alias in aliases
+                            if alias.name.value != import_info.to_import_str[-1]
+                        ]
+                        names[-1] = names[-1].with_changes(
+                            comma=cst.MaybeSentinel.DEFAULT
+                        )
+                        updated_node = updated_node.with_changes(names=names)
+                    else:
+                        return cst.RemoveFromParent()  # type: ignore[return-value]
+            return updated_node

-    source = textwrap.dedent(
-        """
-        from pydantic.settings import BaseSettings
-        from pydantic.color import Color
-        from pydantic.payment import PaymentCardNumber, PaymentCardBrand
-        from pydantic import Color
-        from pydantic import Color as Potato
-
-
-        class Potato(BaseSettings):
-            color: Color
-            payment: PaymentCardNumber
-            brand: PaymentCardBrand
-            potato: Potato
-        """
-    )
-    console.print(source)
-    console.print("=" * 80)
-    mod = cst.parse_module(source)
-    context = CodemodContext(filename="main.py")
-    wrapper = cst.MetadataWrapper(mod)
-    command = ReplaceImportsCodemod(context=context)
-    mod = wrapper.visit(command)
-    wrapper = cst.MetadataWrapper(mod)
-    command = AddImportsVisitor(context=context)  # type: ignore[assignment]
-    mod = wrapper.visit(command)
-    console.print(mod.code)
+    return ReplaceImportsCodemod
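
A minimal sketch of driving the generated codemod directly, adapted from the `__main__` demo this diff removes; the sample source line is illustrative only:

import libcst as cst
from libcst.codemod import CodemodContext
from libcst.codemod.visitors import AddImportsVisitor

from langchain_cli.namespaces.migrate.codemods.replace_imports import (
    generate_import_replacer,
)

ReplaceImports = generate_import_replacer(["langchain_to_text_splitters"])
context = CodemodContext(filename="main.py")
mod = cst.parse_module("from langchain.text_splitter import TextSplitter\n")
# The replacer removes/trims the old import and queues the new one ...
mod = ReplaceImports(context=context).transform_module(mod)
# ... which AddImportsVisitor, sharing the same context, then inserts.
mod = AddImportsVisitor(context=context).transform_module(mod)
print(mod.code)  # expected: from langchain_text_splitters import TextSplitter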

View File

@@ -6,7 +6,7 @@ from typing import List, Tuple


 def generate_raw_migrations(
-    from_package: str, to_package: str
+    from_package: str, to_package: str, filter_by_all: bool = False
 ) -> List[Tuple[str, str]]:
     """Scan the `langchain` package and generate migrations for all modules."""
     package = importlib.import_module(from_package)
@@ -40,15 +40,17 @@ def generate_raw_migrations(
                     (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}")
                 )

-    # Iterate over all members of the module
-    for name, obj in inspect.getmembers(module):
-        # Check if it's a class or function
-        if inspect.isclass(obj) or inspect.isfunction(obj):
-            # Check if the module name of the obj starts with 'langchain_community'
-            if obj.__module__.startswith(to_package):
-                items.append(
-                    (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}")
-                )
+    if not filter_by_all:
+        # Iterate over all members of the module
+        for name, obj in inspect.getmembers(module):
+            # Check if it's a class or function
+            if inspect.isclass(obj) or inspect.isfunction(obj):
+                # Check if the module name of the obj starts with
+                # 'langchain_community'
+                if obj.__module__.startswith(to_package):
+                    items.append(
+                        (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}")
+                    )

     return items
@@ -77,45 +79,52 @@ def generate_top_level_imports(pkg: str) -> List[Tuple[str, str]]:
     to importing it from the top level namespaces
     (e.g., langchain_community.chat_models.XYZ)
     """
-    import importlib
-
     package = importlib.import_module(pkg)
     items = []

+    # Function to handle importing from modules
+    def handle_module(module, module_name):
+        if hasattr(module, "__all__"):
+            all_objects = getattr(module, "__all__")
+            for name in all_objects:
+                # Attempt to fetch each object declared in __all__
+                obj = getattr(module, name, None)
+                if obj and (inspect.isclass(obj) or inspect.isfunction(obj)):
+                    # Capture the fully qualified name of the object
+                    original_module = obj.__module__
+                    original_name = obj.__name__
+                    # Form the new import path from the top-level namespace
+                    top_level_import = f"{module_name}.{name}"
+                    # Append the tuple with original and top-level paths
+                    items.append(
+                        (f"{original_module}.{original_name}", top_level_import)
+                    )
+
+    # Handle the package itself (root level)
+    handle_module(package, pkg)
+
     # Only iterate through top-level modules/packages
     for finder, modname, ispkg in pkgutil.iter_modules(
         package.__path__, package.__name__ + "."
     ):
         if ispkg:
             try:
                 module = importlib.import_module(modname)
+                handle_module(module, modname)
             except ModuleNotFoundError:
                 continue

-            if hasattr(module, "__all__"):
-                all_objects = getattr(module, "__all__")
-                for name in all_objects:
-                    # Attempt to fetch each object declared in __all__
-                    obj = getattr(module, name, None)
-                    if obj and (inspect.isclass(obj) or inspect.isfunction(obj)):
-                        # Capture the fully qualified name of the object
-                        original_module = obj.__module__
-                        original_name = obj.__name__
-                        # Form the new import path from the top-level namespace
-                        top_level_import = f"{modname}.{name}"
-                        # Append the tuple with original and top-level paths
-                        items.append(
-                            (f"{original_module}.{original_name}", top_level_import)
-                        )
     return items


 def generate_simplified_migrations(
-    from_package: str, to_package: str
+    from_package: str, to_package: str, filter_by_all: bool = True
 ) -> List[Tuple[str, str]]:
     """Get all the raw migrations, then simplify them if possible."""
-    raw_migrations = generate_raw_migrations(from_package, to_package)
+    raw_migrations = generate_raw_migrations(
+        from_package, to_package, filter_by_all=filter_by_all
+    )
     top_level_simplifications = generate_top_level_imports(to_package)
     top_level_dict = {full: top_level for full, top_level in top_level_simplifications}
     simple_migrations = []
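
A sketch of the new knob from the caller's side; the package names are the ones exercised by the tests later in this commit:

from langchain_cli.namespaces.migrate.generate.generic import (
    generate_simplified_migrations,
)

# filter_by_all=True (the new default) only records symbols reachable via
# each module's __all__; filter_by_all=False also scans plain members.
pairs = generate_simplified_migrations(
    from_package="langchain",
    to_package="langchain_community",
    filter_by_all=True,
)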

View File

@@ -8,7 +8,7 @@ import os
 import time
 import traceback
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union

 import libcst as cst
 import rich
@@ -41,6 +41,9 @@ def main(
         default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
     ),
     log_file: Path = Option("log.txt", help="Log errors to this file."),
+    include_ipynb: bool = Option(
+        False, help="Include Jupyter Notebook files in the migration."
+    ),
 ):
     """Migrate langchain to the most recent version."""
     if not diff:
@@ -63,6 +66,8 @@ def main(
     else:
         package = path
     all_files = sorted(package.glob("**/*.py"))
+    if include_ipynb:
+        all_files.extend(sorted(package.glob("**/*.ipynb")))

     filtered_files = [
         file
@@ -86,11 +91,9 @@ def main(
     scratch: dict[str, Any] = {}
     start_time = time.time()

-    codemods = gather_codemods(disabled=disable)
-
     log_fp = log_file.open("a+", encoding="utf8")
     partial_run_codemods = functools.partial(
-        run_codemods, codemods, metadata_manager, scratch, package, diff
+        get_and_run_codemods, disable, metadata_manager, scratch, package, diff
    )
     with Progress(*Progress.get_default_columns(), transient=True) as progress:
         task = progress.add_task(description="Executing codemods...", total=len(files))
@@ -127,6 +130,121 @@ def main(
     raise Exit(1)


+def get_and_run_codemods(
+    disabled_rules: List[Rule],
+    metadata_manager: FullRepoManager,
+    scratch: Dict[str, Any],
+    package: Path,
+    diff: bool,
+    filename: str,
+) -> Tuple[Union[str, None], Union[List[str], None]]:
+    """Run codemods from rules.
+
+    Wrapper around run_codemods to be used with multiprocessing.Pool.
+    """
+    codemods = gather_codemods(disabled=disabled_rules)
+    return run_codemods(codemods, metadata_manager, scratch, package, diff, filename)
+
+
+def _rewrite_file(
+    filename: str,
+    codemods: List[Type[ContextAwareTransformer]],
+    diff: bool,
+    context: CodemodContext,
+) -> Tuple[Union[str, None], Union[List[str], None]]:
+    file_path = Path(filename)
+    with file_path.open("r+", encoding="utf-8") as fp:
+        code = fp.read()
+        fp.seek(0)
+
+        input_tree = cst.parse_module(code)
+
+        for codemod in codemods:
+            transformer = codemod(context=context)
+            output_tree = transformer.transform_module(input_tree)
+            input_tree = output_tree
+
+        output_code = input_tree.code
+        if code != output_code:
+            if diff:
+                lines = difflib.unified_diff(
+                    code.splitlines(keepends=True),
+                    output_code.splitlines(keepends=True),
+                    fromfile=filename,
+                    tofile=filename,
+                )
+                return None, list(lines)
+            else:
+                fp.write(output_code)
+                fp.truncate()
+    return None, None
+
+
+def _rewrite_notebook(
+    filename: str,
+    codemods: List[Type[ContextAwareTransformer]],
+    diff: bool,
+    context: CodemodContext,
+) -> Tuple[Optional[str], Optional[List[str]]]:
+    """Try to rewrite a Jupyter Notebook file."""
+    import nbformat
+
+    file_path = Path(filename)
+    if file_path.suffix != ".ipynb":
+        raise ValueError("Only Jupyter Notebook files (.ipynb) are supported.")
+
+    with file_path.open("r", encoding="utf-8") as fp:
+        notebook = nbformat.read(fp, as_version=4)
+
+    diffs = []
+
+    for cell in notebook.cells:
+        if cell.cell_type == "code":
+            code = "".join(cell.source)
+
+            # Skip code if any of the lines begin with a magic command or
+            # a ! command.
+            # We can try to handle later.
+            if any(
+                line.startswith("!") or line.startswith("%")
+                for line in code.splitlines()
+            ):
+                continue
+
+            input_tree = cst.parse_module(code)
+
+            # TODO(Team): Quick hack, need to figure out
+            # how to handle this correctly.
+            # This prevents the code from trying to re-insert the imports
+            # for every cell in the notebook.
+            local_context = CodemodContext()
+
+            for codemod in codemods:
+                transformer = codemod(context=local_context)
+                output_tree = transformer.transform_module(input_tree)
+                input_tree = output_tree
+
+            output_code = input_tree.code
+            if code != output_code:
+                cell.source = output_code.splitlines(keepends=True)
+                if diff:
+                    cell_diff = difflib.unified_diff(
+                        code.splitlines(keepends=True),
+                        output_code.splitlines(keepends=True),
+                        fromfile=filename,
+                        tofile=filename,
+                    )
+                    diffs.extend(list(cell_diff))
+
+    if diff:
+        return None, diffs
+
+    with file_path.open("w", encoding="utf-8") as fp:
+        nbformat.write(notebook, fp)
+
+    return None, None
+
+
 def run_codemods(
     codemods: List[Type[ContextAwareTransformer]],
     metadata_manager: FullRepoManager,
@@ -145,32 +263,10 @@ def run_codemods(
         )
         context.scratch.update(scratch)

-        file_path = Path(filename)
-        with file_path.open("r+", encoding="utf-8") as fp:
-            code = fp.read()
-            fp.seek(0)
-
-            input_tree = cst.parse_module(code)
-
-            for codemod in codemods:
-                transformer = codemod(context=context)
-                output_tree = transformer.transform_module(input_tree)
-                input_tree = output_tree
-
-            output_code = input_tree.code
-            if code != output_code:
-                if diff:
-                    lines = difflib.unified_diff(
-                        code.splitlines(keepends=True),
-                        output_code.splitlines(keepends=True),
-                        fromfile=filename,
-                        tofile=filename,
-                    )
-                    return None, list(lines)
-                else:
-                    fp.write(output_code)
-                    fp.truncate()
-        return None, None
+        if filename.endswith(".ipynb"):
+            return _rewrite_notebook(filename, codemods, diff, context)
+        else:
+            return _rewrite_file(filename, codemods, diff, context)
     except cst.ParserSyntaxError as exc:
         return (
             f"A syntax error happened on {filename}. This file cannot be "
View File

@@ -32,10 +32,15 @@ def cli():
     default=None,
     help="Output file for the migration script.",
 )
-def generic(pkg1: str, pkg2: str, output: str) -> None:
+@click.option(
+    "--filter-by-all/--no-filter-by-all",
+    default=True,
+    help="Output file for the migration script.",
+)
+def generic(pkg1: str, pkg2: str, output: str, filter_by_all: bool) -> None:
     """Generate a migration script."""
     click.echo("Migration script generated.")
-    migrations = generate_simplified_migrations(pkg1, pkg2)
+    migrations = generate_simplified_migrations(pkg1, pkg2, filter_by_all=filter_by_all)

     if output is None:
         output = f"{pkg1}_to_{pkg2}.json"

View File

@@ -1,45 +1,55 @@
+from langchain._api import suppress_langchain_deprecation_warning as sup2
+from langchain_core._api import suppress_langchain_deprecation_warning as sup1
+
 from langchain_cli.namespaces.migrate.generate.generic import (
     generate_simplified_migrations,
+    generate_raw_migrations,
 )


 def test_create_json_agent_migration() -> None:
     """Test the migration of create_json_agent from langchain to langchain_community."""
-    raw_migrations = generate_simplified_migrations(
-        from_package="langchain", to_package="langchain_community"
-    )
-    json_agent_migrations = [
-        migration for migration in raw_migrations if "create_json_agent" in migration[0]
-    ]
-    assert json_agent_migrations == [
-        (
-            "langchain.agents.create_json_agent",
-            "langchain_community.agent_toolkits.create_json_agent",
-        ),
-        (
-            "langchain.agents.agent_toolkits.create_json_agent",
-            "langchain_community.agent_toolkits.create_json_agent",
-        ),
-        (
-            "langchain.agents.agent_toolkits.json.base.create_json_agent",
-            "langchain_community.agent_toolkits.create_json_agent",
-        ),
-    ]
+    with sup1():
+        with sup2():
+            raw_migrations = generate_simplified_migrations(
+                from_package="langchain", to_package="langchain_community"
+            )
+            json_agent_migrations = [
+                migration
+                for migration in raw_migrations
+                if "create_json_agent" in migration[0]
+            ]
+            assert json_agent_migrations == [
+                (
+                    "langchain.agents.create_json_agent",
+                    "langchain_community.agent_toolkits.create_json_agent",
+                ),
+                (
+                    "langchain.agents.agent_toolkits.create_json_agent",
+                    "langchain_community.agent_toolkits.create_json_agent",
+                ),
+                (
+                    "langchain.agents.agent_toolkits.json.base.create_json_agent",
+                    "langchain_community.agent_toolkits.create_json_agent",
+                ),
+            ]


 def test_create_single_store_retriever_db() -> None:
     """Test migration from langchain to langchain_core"""
-    raw_migrations = generate_simplified_migrations(
-        from_package="langchain", to_package="langchain_core"
-    )
-    # SingleStore was an old name for VectorStoreRetriever
-    single_store_migration = [
-        migration for migration in raw_migrations if "SingleStore" in migration[0]
-    ]
-    assert single_store_migration == [
-        (
-            "langchain.vectorstores.singlestoredb.SingleStoreDBRetriever",
-            "langchain_core.vectorstores.VectorStoreRetriever",
-        ),
-    ]
+    with sup1():
+        with sup2():
+            raw_migrations = generate_simplified_migrations(
+                from_package="langchain", to_package="langchain_core"
+            )
+            # SingleStore was an old name for VectorStoreRetriever
+            single_store_migration = [
+                migration
+                for migration in raw_migrations
+                if "SingleStore" in migration[0]
+            ]
+            assert single_store_migration == [
+                (
+                    "langchain.vectorstores.singlestoredb.SingleStoreDBRetriever",
+                    "langchain_core.vectorstores.VectorStoreRetriever",
+                ),
+            ]

View File

@@ -7,9 +7,18 @@ pytest.importorskip("libcst")

 from libcst.codemod import CodemodTest

 from langchain_cli.namespaces.migrate.codemods.replace_imports import (
-    ReplaceImportsCodemod,
+    generate_import_replacer,
 )

+ReplaceImportsCodemod = generate_import_replacer(
+    [
+        "langchain_to_community",
+        "community_to_partner",
+        "langchain_to_core",
+        "community_to_core",
+    ]
+)  # type: ignore[attr-defined] # noqa: E501
+

 class TestReplaceImportsCommand(CodemodTest):
     TRANSFORM = ReplaceImportsCodemod