cli[minor]: Add ipynb support, add text_splitters (#20963)

Eugene Yurtsev 2024-04-29 10:11:21 -04:00 committed by GitHub
parent 5e0b6b3e75
commit d781560722
13 changed files with 2632 additions and 6525 deletions


@@ -5,21 +5,41 @@ from libcst.codemod import ContextAwareTransformer
 from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
 from langchain_cli.namespaces.migrate.codemods.replace_imports import (
-    ReplaceImportsCodemod,
+    generate_import_replacer,
 )
 
 
 class Rule(str, Enum):
-    R001 = "R001"
-    """Replace imports that have been moved."""
+    langchain_to_community = "langchain_to_community"
+    """Replace deprecated langchain imports with current ones in community."""
+    langchain_to_core = "langchain_to_core"
+    """Replace deprecated langchain imports with current ones in core."""
+    langchain_to_text_splitters = "langchain_to_text_splitters"
+    """Replace deprecated langchain imports with current ones in text splitters."""
+    community_to_core = "community_to_core"
+    """Replace deprecated community imports with current ones in core."""
+    community_to_partner = "community_to_partner"
+    """Replace deprecated community imports with current ones in partner."""
 
 
 def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
+    """Gather codemods based on the disabled rules."""
     codemods: List[Type[ContextAwareTransformer]] = []
-    if Rule.R001 not in disabled:
-        codemods.append(ReplaceImportsCodemod)
+    # Import rules
+    import_rules = {
+        Rule.langchain_to_community,
+        Rule.langchain_to_core,
+        Rule.community_to_core,
+        Rule.community_to_partner,
+        Rule.langchain_to_text_splitters,
+    }
+
+    # Find active import rules
+    active_import_rules = import_rules - set(disabled)
+
+    if active_import_rules:
+        codemods.append(generate_import_replacer(active_import_rules))
     # Those codemods need to be the last ones.
     codemods.extend([RemoveImportsVisitor, AddImportsVisitor])
     return codemods
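For context, the refactor replaces the single R001 switch with one rule per migration target. A minimal sketch of how the new API composes (the module path for Rule and gather_codemods is assumed from the imports above, not stated in the diff):

# Minimal sketch, assuming Rule and gather_codemods live in
# langchain_cli.namespaces.migrate.codemods (inferred from the imports above).
from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods

# Disable only the partner migrations; the other import rules stay active,
# so a single combined import-replacer codemod is generated for them.
codemods = gather_codemods(disabled=[Rule.community_to_partner])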


@@ -0,0 +1,110 @@
[
[
"langchain_community.callbacks.tracers.ConsoleCallbackHandler",
"langchain_core.tracers.ConsoleCallbackHandler"
],
[
"langchain_community.callbacks.tracers.FunctionCallbackHandler",
"langchain_core.tracers.stdout.FunctionCallbackHandler"
],
[
"langchain_community.callbacks.tracers.LangChainTracer",
"langchain_core.tracers.LangChainTracer"
],
[
"langchain_community.callbacks.tracers.LangChainTracerV1",
"langchain_core.tracers.langchain_v1.LangChainTracerV1"
],
[
"langchain_community.docstore.document.Document",
"langchain_core.documents.Document"
],
[
"langchain_community.document_loaders.Blob",
"langchain_core.document_loaders.Blob"
],
[
"langchain_community.document_loaders.BlobLoader",
"langchain_core.document_loaders.BlobLoader"
],
[
"langchain_community.document_loaders.base.BaseBlobParser",
"langchain_core.document_loaders.BaseBlobParser"
],
[
"langchain_community.document_loaders.base.BaseLoader",
"langchain_core.document_loaders.BaseLoader"
],
[
"langchain_community.document_loaders.blob_loaders.Blob",
"langchain_core.document_loaders.Blob"
],
[
"langchain_community.document_loaders.blob_loaders.BlobLoader",
"langchain_core.document_loaders.BlobLoader"
],
[
"langchain_community.document_loaders.blob_loaders.schema.Blob",
"langchain_core.document_loaders.Blob"
],
[
"langchain_community.document_loaders.blob_loaders.schema.BlobLoader",
"langchain_core.document_loaders.BlobLoader"
],
[
"langchain_community.tools.BaseTool",
"langchain_core.tools.BaseTool"
],
[
"langchain_community.tools.StructuredTool",
"langchain_core.tools.StructuredTool"
],
[
"langchain_community.tools.Tool",
"langchain_core.tools.Tool"
],
[
"langchain_community.tools.format_tool_to_openai_function",
"langchain_core.utils.function_calling.format_tool_to_openai_function"
],
[
"langchain_community.tools.tool",
"langchain_core.tools.tool"
],
[
"langchain_community.tools.convert_to_openai.format_tool_to_openai_function",
"langchain_core.utils.function_calling.format_tool_to_openai_function"
],
[
"langchain_community.tools.convert_to_openai.format_tool_to_openai_tool",
"langchain_core.utils.function_calling.format_tool_to_openai_tool"
],
[
"langchain_community.tools.render.format_tool_to_openai_function",
"langchain_core.utils.function_calling.format_tool_to_openai_function"
],
[
"langchain_community.tools.render.format_tool_to_openai_tool",
"langchain_core.utils.function_calling.format_tool_to_openai_tool"
],
[
"langchain_community.utils.openai_functions.FunctionDescription",
"langchain_core.utils.function_calling.FunctionDescription"
],
[
"langchain_community.utils.openai_functions.ToolDescription",
"langchain_core.utils.function_calling.ToolDescription"
],
[
"langchain_community.utils.openai_functions.convert_pydantic_to_openai_function",
"langchain_core.utils.function_calling.convert_pydantic_to_openai_function"
],
[
"langchain_community.utils.openai_functions.convert_pydantic_to_openai_tool",
"langchain_core.utils.function_calling.convert_pydantic_to_openai_tool"
],
[
"langchain_community.vectorstores.VectorStore",
"langchain_core.vectorstores.VectorStore"
]
]
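Each pair in this new fixture maps an old import path to its new home. Applied to user code, the BaseTool mapping above rewrites, for example:

# Before migration:
from langchain_community.tools import BaseTool
# After migration (per the "langchain_community.tools.BaseTool" mapping above):
from langchain_core.tools import BaseTool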


@@ -383,314 +383,10 @@
     "langchain.agents.agent_toolkits.steam.toolkit.SteamToolkit",
     "langchain_community.agent_toolkits.SteamToolkit"
   ],
-  [
-    "langchain.agents.agent_toolkits.vectorstore.toolkit.BaseToolkit",
-    "langchain_community.agent_toolkits.base.BaseToolkit"
-  ],
-  [
-    "langchain.agents.agent_toolkits.vectorstore.toolkit.OpenAI",
-    "langchain_community.llms.OpenAI"
-  ],
-  [
-    "langchain.agents.agent_toolkits.vectorstore.toolkit.VectorStoreQATool",
-    "langchain_community.tools.VectorStoreQATool"
-  ],
-  [
-    "langchain.agents.agent_toolkits.vectorstore.toolkit.VectorStoreQAWithSourcesTool",
-    "langchain_community.tools.VectorStoreQAWithSourcesTool"
-  ],
   [
     "langchain.agents.agent_toolkits.zapier.toolkit.ZapierToolkit",
     "langchain_community.agent_toolkits.ZapierToolkit"
   ],
-  [
-    "langchain.agents.load_tools.ArxivAPIWrapper",
-    "langchain_community.utilities.ArxivAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.ArxivQueryRun",
-    "langchain_community.tools.ArxivQueryRun"
-  ],
-  [
-    "langchain.agents.load_tools.BaseGraphQLTool",
-    "langchain_community.tools.BaseGraphQLTool"
-  ],
-  [
-    "langchain.agents.load_tools.BingSearchAPIWrapper",
-    "langchain_community.utilities.BingSearchAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.BingSearchRun",
-    "langchain_community.tools.BingSearchRun"
-  ],
-  [
-    "langchain.agents.load_tools.DallEAPIWrapper",
-    "langchain_community.utilities.dalle_image_generator.DallEAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.DataForSeoAPISearchResults",
-    "langchain_community.tools.dataforseo_api_search.tool.DataForSeoAPISearchResults"
-  ],
-  [
-    "langchain.agents.load_tools.DataForSeoAPISearchRun",
-    "langchain_community.tools.dataforseo_api_search.tool.DataForSeoAPISearchRun"
-  ],
-  [
-    "langchain.agents.load_tools.DataForSeoAPIWrapper",
-    "langchain_community.utilities.dataforseo_api_search.DataForSeoAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.DuckDuckGoSearchAPIWrapper",
-    "langchain_community.utilities.DuckDuckGoSearchAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.DuckDuckGoSearchRun",
-    "langchain_community.tools.DuckDuckGoSearchRun"
-  ],
-  [
-    "langchain.agents.load_tools.ElevenLabsText2SpeechTool",
-    "langchain_community.tools.ElevenLabsText2SpeechTool"
-  ],
-  [
-    "langchain.agents.load_tools.GoldenQueryAPIWrapper",
-    "langchain_community.utilities.GoldenQueryAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.GoldenQueryRun",
-    "langchain_community.tools.golden_query.tool.GoldenQueryRun"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleCloudTextToSpeechTool",
-    "langchain_community.tools.GoogleCloudTextToSpeechTool"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleFinanceAPIWrapper",
-    "langchain_community.utilities.GoogleFinanceAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleFinanceQueryRun",
-    "langchain_community.tools.google_finance.tool.GoogleFinanceQueryRun"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleJobsAPIWrapper",
-    "langchain_community.utilities.GoogleJobsAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleJobsQueryRun",
-    "langchain_community.tools.google_jobs.tool.GoogleJobsQueryRun"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleLensAPIWrapper",
-    "langchain_community.utilities.GoogleLensAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleLensQueryRun",
-    "langchain_community.tools.google_lens.tool.GoogleLensQueryRun"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleScholarAPIWrapper",
-    "langchain_community.utilities.GoogleScholarAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleScholarQueryRun",
-    "langchain_community.tools.google_scholar.tool.GoogleScholarQueryRun"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleSearchAPIWrapper",
-    "langchain_community.utilities.GoogleSearchAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleSearchResults",
-    "langchain_community.tools.GoogleSearchResults"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleSearchRun",
-    "langchain_community.tools.GoogleSearchRun"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleSerperAPIWrapper",
-    "langchain_community.utilities.GoogleSerperAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleSerperResults",
-    "langchain_community.tools.GoogleSerperResults"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleSerperRun",
-    "langchain_community.tools.GoogleSerperRun"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleTrendsAPIWrapper",
-    "langchain_community.utilities.GoogleTrendsAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.GoogleTrendsQueryRun",
-    "langchain_community.tools.google_trends.tool.GoogleTrendsQueryRun"
-  ],
-  [
-    "langchain.agents.load_tools.GraphQLAPIWrapper",
-    "langchain_community.utilities.GraphQLAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.HumanInputRun",
-    "langchain_community.tools.HumanInputRun"
-  ],
-  [
-    "langchain.agents.load_tools.LambdaWrapper",
-    "langchain_community.utilities.LambdaWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.Memorize",
-    "langchain_community.tools.memorize.tool.Memorize"
-  ],
-  [
-    "langchain.agents.load_tools.MerriamWebsterAPIWrapper",
-    "langchain_community.utilities.MerriamWebsterAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.MerriamWebsterQueryRun",
-    "langchain_community.tools.MerriamWebsterQueryRun"
-  ],
-  [
-    "langchain.agents.load_tools.MetaphorSearchAPIWrapper",
-    "langchain_community.utilities.MetaphorSearchAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.MetaphorSearchResults",
-    "langchain_community.tools.MetaphorSearchResults"
-  ],
-  [
-    "langchain.agents.load_tools.OpenWeatherMapAPIWrapper",
-    "langchain_community.utilities.OpenWeatherMapAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.OpenWeatherMapQueryRun",
-    "langchain_community.tools.OpenWeatherMapQueryRun"
-  ],
-  [
-    "langchain.agents.load_tools.PubMedAPIWrapper",
-    "langchain_community.utilities.PubMedAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.PubmedQueryRun",
-    "langchain_community.tools.PubmedQueryRun"
-  ],
-  [
-    "langchain.agents.load_tools.RedditSearchAPIWrapper",
-    "langchain_community.utilities.reddit_search.RedditSearchAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.RedditSearchRun",
-    "langchain_community.tools.RedditSearchRun"
-  ],
-  [
-    "langchain.agents.load_tools.RequestsDeleteTool",
-    "langchain_community.tools.RequestsDeleteTool"
-  ],
-  [
-    "langchain.agents.load_tools.RequestsGetTool",
-    "langchain_community.tools.RequestsGetTool"
-  ],
-  [
-    "langchain.agents.load_tools.RequestsPatchTool",
-    "langchain_community.tools.RequestsPatchTool"
-  ],
-  [
-    "langchain.agents.load_tools.RequestsPostTool",
-    "langchain_community.tools.RequestsPostTool"
-  ],
-  [
-    "langchain.agents.load_tools.RequestsPutTool",
-    "langchain_community.tools.RequestsPutTool"
-  ],
-  [
-    "langchain.agents.load_tools.SceneXplainTool",
-    "langchain_community.tools.SceneXplainTool"
-  ],
-  [
-    "langchain.agents.load_tools.SearchAPIResults",
-    "langchain_community.tools.SearchAPIResults"
-  ],
-  [
-    "langchain.agents.load_tools.SearchAPIRun",
-    "langchain_community.tools.SearchAPIRun"
-  ],
-  [
-    "langchain.agents.load_tools.SearchApiAPIWrapper",
-    "langchain_community.utilities.SearchApiAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.SearxSearchResults",
-    "langchain_community.tools.SearxSearchResults"
-  ],
-  [
-    "langchain.agents.load_tools.SearxSearchRun",
-    "langchain_community.tools.SearxSearchRun"
-  ],
-  [
-    "langchain.agents.load_tools.SearxSearchWrapper",
-    "langchain_community.utilities.SearxSearchWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.SerpAPIWrapper",
-    "langchain_community.utilities.SerpAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.ShellTool",
-    "langchain_community.tools.ShellTool"
-  ],
-  [
-    "langchain.agents.load_tools.SleepTool",
-    "langchain_community.tools.SleepTool"
-  ],
-  [
-    "langchain.agents.load_tools.StackExchangeAPIWrapper",
-    "langchain_community.utilities.StackExchangeAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.StackExchangeTool",
-    "langchain_community.tools.StackExchangeTool"
-  ],
-  [
-    "langchain.agents.load_tools.TextRequestsWrapper",
-    "langchain_community.utilities.TextRequestsWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.TwilioAPIWrapper",
-    "langchain_community.utilities.TwilioAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.WikipediaAPIWrapper",
-    "langchain_community.utilities.WikipediaAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.WikipediaQueryRun",
-    "langchain_community.tools.WikipediaQueryRun"
-  ],
-  [
-    "langchain.agents.load_tools.WolframAlphaAPIWrapper",
-    "langchain_community.utilities.WolframAlphaAPIWrapper"
-  ],
-  [
-    "langchain.agents.load_tools.WolframAlphaQueryRun",
-    "langchain_community.tools.WolframAlphaQueryRun"
-  ],
-  [
-    "langchain.agents.react.base.Docstore",
-    "langchain_community.docstore.base.Docstore"
-  ],
-  [
-    "langchain.agents.self_ask_with_search.base.GoogleSerperAPIWrapper",
-    "langchain_community.utilities.GoogleSerperAPIWrapper"
-  ],
-  [
-    "langchain.agents.self_ask_with_search.base.SearchApiAPIWrapper",
-    "langchain_community.utilities.SearchApiAPIWrapper"
-  ],
-  [
-    "langchain.agents.self_ask_with_search.base.SerpAPIWrapper",
-    "langchain_community.utilities.SerpAPIWrapper"
-  ],
   [
     "langchain.cache.InMemoryCache",
     "langchain_community.cache.InMemoryCache"
@@ -955,14 +651,6 @@
     "langchain.callbacks.sagemaker_callback.SageMakerCallbackHandler",
     "langchain_community.callbacks.SageMakerCallbackHandler"
   ],
-  [
-    "langchain.callbacks.streamlit.LLMThoughtLabeler",
-    "langchain_community.callbacks.LLMThoughtLabeler"
-  ],
-  [
-    "langchain.callbacks.streamlit._InternalStreamlitCallbackHandler",
-    "langchain_community.callbacks.streamlit.streamlit_callback_handler.StreamlitCallbackHandler"
-  ],
   [
     "langchain.callbacks.streamlit.mutable_expander.ChildType",
     "langchain_community.callbacks.streamlit.mutable_expander.ChildType"
@@ -1063,126 +751,6 @@
     "langchain.callbacks.whylabs_callback.WhyLabsCallbackHandler",
     "langchain_community.callbacks.WhyLabsCallbackHandler"
   ],
-  [
-    "langchain.chains.api.base.TextRequestsWrapper",
-    "langchain_community.utilities.TextRequestsWrapper"
-  ],
-  [
-    "langchain.chains.api.openapi.chain.APIOperation",
-    "langchain_community.tools.APIOperation"
-  ],
-  [
-    "langchain.chains.api.openapi.chain.Requests",
-    "langchain_community.utilities.Requests"
-  ],
-  [
-    "langchain.chains.ernie_functions.base.JsonOutputFunctionsParser",
-    "langchain_community.output_parsers.ernie_functions.JsonOutputFunctionsParser"
-  ],
-  [
-    "langchain.chains.ernie_functions.base.PydanticAttrOutputFunctionsParser",
-    "langchain_community.output_parsers.ernie_functions.PydanticAttrOutputFunctionsParser"
-  ],
-  [
-    "langchain.chains.ernie_functions.base.PydanticOutputFunctionsParser",
-    "langchain_community.output_parsers.ernie_functions.PydanticOutputFunctionsParser"
-  ],
-  [
-    "langchain.chains.ernie_functions.base.convert_pydantic_to_ernie_function",
-    "langchain_community.utils.ernie_functions.convert_pydantic_to_ernie_function"
-  ],
-  [
-    "langchain.chains.flare.base.OpenAI",
-    "langchain_community.llms.OpenAI"
-  ],
-  [
-    "langchain.chains.graph_qa.arangodb.ArangoGraph",
-    "langchain_community.graphs.ArangoGraph"
-  ],
-  [
-    "langchain.chains.graph_qa.base.NetworkxEntityGraph",
-    "langchain_community.graphs.NetworkxEntityGraph"
-  ],
-  [
-    "langchain.chains.graph_qa.base.get_entities",
-    "langchain_community.graphs.networkx_graph.get_entities"
-  ],
-  [
-    "langchain.chains.graph_qa.cypher.GraphStore",
-    "langchain_community.graphs.graph_store.GraphStore"
-  ],
-  [
-    "langchain.chains.graph_qa.falkordb.FalkorDBGraph",
-    "langchain_community.graphs.FalkorDBGraph"
-  ],
-  [
-    "langchain.chains.graph_qa.gremlin.GremlinGraph",
-    "langchain_community.graphs.GremlinGraph"
-  ],
-  [
-    "langchain.chains.graph_qa.hugegraph.HugeGraph",
-    "langchain_community.graphs.HugeGraph"
-  ],
-  [
-    "langchain.chains.graph_qa.kuzu.KuzuGraph",
-    "langchain_community.graphs.KuzuGraph"
-  ],
-  [
-    "langchain.chains.graph_qa.nebulagraph.NebulaGraph",
-    "langchain_community.graphs.NebulaGraph"
-  ],
-  [
-    "langchain.chains.graph_qa.neptune_cypher.NeptuneGraph",
-    "langchain_community.graphs.NeptuneGraph"
-  ],
-  [
-    "langchain.chains.graph_qa.neptune_sparql.NeptuneRdfGraph",
-    "langchain_community.graphs.NeptuneRdfGraph"
-  ],
-  [
-    "langchain.chains.graph_qa.ontotext_graphdb.OntotextGraphDBGraph",
-    "langchain_community.graphs.OntotextGraphDBGraph"
-  ],
-  [
-    "langchain.chains.graph_qa.sparql.RdfGraph",
-    "langchain_community.graphs.RdfGraph"
-  ],
-  [
-    "langchain.chains.llm_requests.TextRequestsWrapper",
-    "langchain_community.utilities.TextRequestsWrapper"
-  ],
-  [
-    "langchain.chains.loading.load_llm",
-    "langchain_community.llms.loading.load_llm"
-  ],
-  [
-    "langchain.chains.loading.load_llm_from_config",
-    "langchain_community.llms.loading.load_llm_from_config"
-  ],
-  [
-    "langchain.chains.natbot.base.OpenAI",
-    "langchain_community.llms.OpenAI"
-  ],
-  [
-    "langchain.chains.openai_functions.openapi.APIOperation",
-    "langchain_community.tools.APIOperation"
-  ],
-  [
-    "langchain.chains.openai_functions.openapi.ChatOpenAI",
-    "langchain_community.chat_models.ChatOpenAI"
-  ],
-  [
-    "langchain.chains.openai_functions.openapi.OpenAPISpec",
-    "langchain_community.tools.OpenAPISpec"
-  ],
-  [
-    "langchain.chains.router.multi_retrieval_qa.ChatOpenAI",
-    "langchain_community.chat_models.ChatOpenAI"
-  ],
-  [
-    "langchain.chains.sql_database.query.SQLDatabase",
-    "langchain_community.utilities.SQLDatabase"
-  ],
   [
     "langchain.chat_loaders.base.BaseChatLoader",
     "langchain_community.chat_loaders.BaseChatLoader"
@@ -3695,34 +3263,6 @@
     "langchain.embeddings.xinference.XinferenceEmbeddings",
     "langchain_community.embeddings.XinferenceEmbeddings"
   ],
-  [
-    "langchain.evaluation.comparison.eval_chain.AzureChatOpenAI",
-    "langchain_community.chat_models.AzureChatOpenAI"
-  ],
-  [
-    "langchain.evaluation.comparison.eval_chain.ChatOpenAI",
-    "langchain_community.chat_models.ChatOpenAI"
-  ],
-  [
-    "langchain.evaluation.embedding_distance.base.OpenAIEmbeddings",
-    "langchain_community.embeddings.OpenAIEmbeddings"
-  ],
-  [
-    "langchain.evaluation.embedding_distance.base.cosine_similarity",
-    "langchain_community.utils.math.cosine_similarity"
-  ],
-  [
-    "langchain.evaluation.loading.ChatOpenAI",
-    "langchain_community.chat_models.ChatOpenAI"
-  ],
-  [
-    "langchain.evaluation.scoring.eval_chain.AzureChatOpenAI",
-    "langchain_community.chat_models.AzureChatOpenAI"
-  ],
-  [
-    "langchain.evaluation.scoring.eval_chain.ChatOpenAI",
-    "langchain_community.chat_models.ChatOpenAI"
-  ],
   [
     "langchain.graphs.MemgraphGraph",
     "langchain_community.graphs.MemgraphGraph"
@@ -3835,26 +3375,6 @@
     "langchain.graphs.rdf_graph.RdfGraph",
     "langchain_community.graphs.RdfGraph"
   ],
-  [
-    "langchain.indexes.graph.NetworkxEntityGraph",
-    "langchain_community.graphs.NetworkxEntityGraph"
-  ],
-  [
-    "langchain.indexes.graph.parse_triples",
-    "langchain_community.graphs.networkx_graph.parse_triples"
-  ],
-  [
-    "langchain.indexes.vectorstore.Chroma",
-    "langchain_community.vectorstores.Chroma"
-  ],
-  [
-    "langchain.indexes.vectorstore.OpenAI",
-    "langchain_community.llms.OpenAI"
-  ],
-  [
-    "langchain.indexes.vectorstore.OpenAIEmbeddings",
-    "langchain_community.embeddings.OpenAIEmbeddings"
-  ],
   [
     "langchain.llms.AI21",
     "langchain_community.llms.AI21"
@@ -4655,10 +4175,6 @@
     "langchain.memory.UpstashRedisChatMessageHistory",
     "langchain_community.chat_message_histories.UpstashRedisChatMessageHistory"
   ],
-  [
-    "langchain.memory.chat_memory.ChatMessageHistory",
-    "langchain_community.chat_message_histories.ChatMessageHistory"
-  ],
   [
     "langchain.memory.chat_message_histories.AstraDBChatMessageHistory",
     "langchain_community.chat_message_histories.AstraDBChatMessageHistory"
@@ -4827,30 +4343,6 @@
     "langchain.memory.chat_message_histories.zep.ZepChatMessageHistory",
     "langchain_community.chat_message_histories.ZepChatMessageHistory"
   ],
-  [
-    "langchain.memory.entity.get_client",
-    "langchain_community.utilities.redis.get_client"
-  ],
-  [
-    "langchain.memory.kg.KnowledgeTriple",
-    "langchain_community.graphs.networkx_graph.KnowledgeTriple"
-  ],
-  [
-    "langchain.memory.kg.NetworkxEntityGraph",
-    "langchain_community.graphs.NetworkxEntityGraph"
-  ],
-  [
-    "langchain.memory.kg.get_entities",
-    "langchain_community.graphs.networkx_graph.get_entities"
-  ],
-  [
-    "langchain.memory.kg.parse_triples",
-    "langchain_community.graphs.networkx_graph.parse_triples"
-  ],
-  [
-    "langchain.memory.zep_memory.ZepChatMessageHistory",
-    "langchain_community.chat_message_histories.ZepChatMessageHistory"
-  ],
   [
     "langchain.output_parsers.GuardrailsOutputParser",
     "langchain_community.output_parsers.rail_parser.GuardrailsOutputParser"
@@ -5103,18 +4595,6 @@
     "langchain.retrievers.docarray.DocArrayRetriever",
     "langchain_community.retrievers.DocArrayRetriever"
   ],
-  [
-    "langchain.retrievers.document_compressors.embeddings_filter._get_embeddings_from_stateful_docs",
-    "langchain_community.document_transformers.embeddings_redundant_filter._get_embeddings_from_stateful_docs"
-  ],
-  [
-    "langchain.retrievers.document_compressors.embeddings_filter.cosine_similarity",
-    "langchain_community.utils.math.cosine_similarity"
-  ],
-  [
-    "langchain.retrievers.document_compressors.embeddings_filter.get_stateful_documents",
-    "langchain_community.document_transformers.get_stateful_documents"
-  ],
   [
     "langchain.retrievers.elastic_search_bm25.ElasticSearchBM25Retriever",
     "langchain_community.retrievers.ElasticSearchBM25Retriever"
@@ -5243,110 +4723,6 @@
     "langchain.retrievers.remote_retriever.RemoteLangChainRetriever",
     "langchain_community.retrievers.RemoteLangChainRetriever"
   ],
-  [
-    "langchain.retrievers.self_query.base.AstraDB",
-    "langchain_community.vectorstores.AstraDB"
-  ],
-  [
-    "langchain.retrievers.self_query.base.Chroma",
-    "langchain_community.vectorstores.Chroma"
-  ],
-  [
-    "langchain.retrievers.self_query.base.DashVector",
-    "langchain_community.vectorstores.DashVector"
-  ],
-  [
-    "langchain.retrievers.self_query.base.DeepLake",
-    "langchain_community.vectorstores.DeepLake"
-  ],
-  [
-    "langchain.retrievers.self_query.base.Dingo",
-    "langchain_community.vectorstores.Dingo"
-  ],
-  [
-    "langchain.retrievers.self_query.base.ElasticsearchStore",
-    "langchain_community.vectorstores.ElasticsearchStore"
-  ],
-  [
-    "langchain.retrievers.self_query.base.Milvus",
-    "langchain_community.vectorstores.Milvus"
-  ],
-  [
-    "langchain.retrievers.self_query.base.MongoDBAtlasVectorSearch",
-    "langchain_community.vectorstores.MongoDBAtlasVectorSearch"
-  ],
-  [
-    "langchain.retrievers.self_query.base.MyScale",
-    "langchain_community.vectorstores.MyScale"
-  ],
-  [
-    "langchain.retrievers.self_query.base.OpenSearchVectorSearch",
-    "langchain_community.vectorstores.OpenSearchVectorSearch"
-  ],
-  [
-    "langchain.retrievers.self_query.base.PGVector",
-    "langchain_community.vectorstores.PGVector"
-  ],
-  [
-    "langchain.retrievers.self_query.base.Pinecone",
-    "langchain_community.vectorstores.Pinecone"
-  ],
-  [
-    "langchain.retrievers.self_query.base.Qdrant",
-    "langchain_community.vectorstores.Qdrant"
-  ],
-  [
-    "langchain.retrievers.self_query.base.Redis",
-    "langchain_community.vectorstores.Redis"
-  ],
-  [
-    "langchain.retrievers.self_query.base.SupabaseVectorStore",
-    "langchain_community.vectorstores.SupabaseVectorStore"
-  ],
-  [
-    "langchain.retrievers.self_query.base.TimescaleVector",
-    "langchain_community.vectorstores.TimescaleVector"
-  ],
-  [
-    "langchain.retrievers.self_query.base.Vectara",
-    "langchain_community.vectorstores.Vectara"
-  ],
-  [
-    "langchain.retrievers.self_query.base.Weaviate",
-    "langchain_community.vectorstores.Weaviate"
-  ],
-  [
-    "langchain.retrievers.self_query.redis.Redis",
-    "langchain_community.vectorstores.Redis"
-  ],
-  [
-    "langchain.retrievers.self_query.redis.RedisFilterExpression",
-    "langchain_community.vectorstores.redis.filters.RedisFilterExpression"
-  ],
-  [
-    "langchain.retrievers.self_query.redis.RedisFilterField",
-    "langchain_community.vectorstores.redis.filters.RedisFilterField"
-  ],
-  [
-    "langchain.retrievers.self_query.redis.RedisFilterOperator",
-    "langchain_community.vectorstores.redis.filters.RedisFilterOperator"
-  ],
-  [
-    "langchain.retrievers.self_query.redis.RedisModel",
-    "langchain_community.vectorstores.redis.schema.RedisModel"
-  ],
-  [
-    "langchain.retrievers.self_query.redis.RedisNum",
-    "langchain_community.vectorstores.redis.filters.RedisNum"
-  ],
-  [
-    "langchain.retrievers.self_query.redis.RedisTag",
-    "langchain_community.vectorstores.redis.filters.RedisTag"
-  ],
-  [
-    "langchain.retrievers.self_query.redis.RedisText",
-    "langchain_community.vectorstores.redis.filters.RedisText"
-  ],
   [
     "langchain.retrievers.svm.SVMRetriever",
     "langchain_community.retrievers.SVMRetriever"
@@ -5371,22 +4747,6 @@
     "langchain.retrievers.weaviate_hybrid_search.WeaviateHybridSearchRetriever",
     "langchain_community.retrievers.WeaviateHybridSearchRetriever"
   ],
-  [
-    "langchain.retrievers.web_research.AsyncHtmlLoader",
-    "langchain_community.document_loaders.AsyncHtmlLoader"
-  ],
-  [
-    "langchain.retrievers.web_research.GoogleSearchAPIWrapper",
-    "langchain_community.utilities.GoogleSearchAPIWrapper"
-  ],
-  [
-    "langchain.retrievers.web_research.Html2TextTransformer",
-    "langchain_community.document_transformers.Html2TextTransformer"
-  ],
-  [
-    "langchain.retrievers.web_research.LlamaCpp",
-    "langchain_community.llms.LlamaCpp"
-  ],
   [
     "langchain.retrievers.wikipedia.WikipediaRetriever",
     "langchain_community.retrievers.WikipediaRetriever"
@@ -5439,10 +4799,6 @@
     "langchain.storage.exceptions.InvalidKeyException",
     "langchain_community.storage.exceptions.InvalidKeyException"
   ],
-  [
-    "langchain.storage.file_system.InvalidKeyException",
-    "langchain_community.storage.exceptions.InvalidKeyException"
-  ],
   [
     "langchain.storage.redis.RedisStore",
     "langchain_community.storage.RedisStore"


@@ -0,0 +1,82 @@
[
[
"langchain.text_splitter.TokenTextSplitter",
"langchain_text_splitters.TokenTextSplitter"
],
[
"langchain.text_splitter.TextSplitter",
"langchain_text_splitters.TextSplitter"
],
[
"langchain.text_splitter.Tokenizer",
"langchain_text_splitters.Tokenizer"
],
[
"langchain.text_splitter.Language",
"langchain_text_splitters.Language"
],
[
"langchain.text_splitter.RecursiveCharacterTextSplitter",
"langchain_text_splitters.RecursiveCharacterTextSplitter"
],
[
"langchain.text_splitter.RecursiveJsonSplitter",
"langchain_text_splitters.RecursiveJsonSplitter"
],
[
"langchain.text_splitter.LatexTextSplitter",
"langchain_text_splitters.LatexTextSplitter"
],
[
"langchain.text_splitter.PythonCodeTextSplitter",
"langchain_text_splitters.PythonCodeTextSplitter"
],
[
"langchain.text_splitter.KonlpyTextSplitter",
"langchain_text_splitters.KonlpyTextSplitter"
],
[
"langchain.text_splitter.SpacyTextSplitter",
"langchain_text_splitters.SpacyTextSplitter"
],
[
"langchain.text_splitter.NLTKTextSplitter",
"langchain_text_splitters.NLTKTextSplitter"
],
[
"langchain.text_splitter.split_text_on_tokens",
"langchain_text_splitters.split_text_on_tokens"
],
[
"langchain.text_splitter.SentenceTransformersTokenTextSplitter",
"langchain_text_splitters.SentenceTransformersTokenTextSplitter"
],
[
"langchain.text_splitter.ElementType",
"langchain_text_splitters.ElementType"
],
[
"langchain.text_splitter.HeaderType",
"langchain_text_splitters.HeaderType"
],
[
"langchain.text_splitter.LineType",
"langchain_text_splitters.LineType"
],
[
"langchain.text_splitter.HTMLHeaderTextSplitter",
"langchain_text_splitters.HTMLHeaderTextSplitter"
],
[
"langchain.text_splitter.MarkdownHeaderTextSplitter",
"langchain_text_splitters.MarkdownHeaderTextSplitter"
],
[
"langchain.text_splitter.MarkdownTextSplitter",
"langchain_text_splitters.MarkdownTextSplitter"
],
[
"langchain.text_splitter.CharacterTextSplitter",
"langchain_text_splitters.CharacterTextSplitter"
]
]
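As with the other fixture tables, each pair rewrites one import. The RecursiveCharacterTextSplitter entry above turns, for example:

# Before migration:
from langchain.text_splitter import RecursiveCharacterTextSplitter
# After migration:
from langchain_text_splitters import RecursiveCharacterTextSplitter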


@@ -15,11 +15,11 @@ from __future__ import annotations
 import json
 import os
 from dataclasses import dataclass
-from typing import Callable, Dict, Iterable, List, Sequence, Tuple, TypeVar
+from typing import Callable, Dict, Iterable, List, Sequence, Tuple, Type, TypeVar
 
 import libcst as cst
 import libcst.matchers as m
-from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
+from libcst.codemod import VisitorBasedCodemodCommand
 from libcst.codemod.visitors import AddImportsVisitor
 
 HERE = os.path.dirname(__file__)
@@ -43,18 +43,8 @@ def _deduplicate_in_order(
     return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]
 
 
-PARTNERS = [
-    "anthropic.json",
-    "ibm.json",
-    "openai.json",
-    "pinecone.json",
-    "fireworks.json",
-]
-
-
-def _load_migrations_from_fixtures() -> List[Tuple[str, str]]:
+def _load_migrations_from_fixtures(paths: List[str]) -> List[Tuple[str, str]]:
     """Load migrations from fixtures."""
-    paths: List[str] = PARTNERS + ["langchain_to_langchain_community.json"]
     data = []
     for path in paths:
         data.extend(_load_migrations_by_file(path))
@@ -62,11 +52,11 @@ def _load_migrations_from_fixtures() -> List[Tuple[str, str]]:
     return data
 
 
-def _load_migrations():
+def _load_migrations(paths: List[str]):
     """Load the migrations from the JSON file."""
     # Later earlier ones have higher precedence.
     imports: Dict[str, Tuple[str, str]] = {}
-    data = _load_migrations_from_fixtures()
+    data = _load_migrations_from_fixtures(paths)
 
     for old_path, new_path in data:
         # Parse the old parse which is of the format 'langchain.chat_models.ChatOpenAI'
@@ -88,9 +78,6 @@ def _load_migrations():
     return imports
 
 
-IMPORTS = _load_migrations()
-
-
 def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name:
     """Converts a list of module parts to a `Name` or `Attribute` node."""
     if len(module_parts) == 1:
@@ -139,23 +126,48 @@ class ImportInfo:
     to_import_str: tuple[str, str]
 
 
-IMPORT_INFOS = [
-    ImportInfo(
-        import_from=get_import_from_from_str(import_str),
-        import_str=import_str,
-        to_import_str=to_import_str,
-    )
-    for import_str, to_import_str in IMPORTS.items()
-]
-IMPORT_MATCH = m.OneOf(*[info.import_from for info in IMPORT_INFOS])
-
-
-class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
-    @m.leave(IMPORT_MATCH)
-    def leave_replace_import(
-        self, _: cst.ImportFrom, updated_node: cst.ImportFrom
-    ) -> cst.ImportFrom:
-        for import_info in IMPORT_INFOS:
-            if m.matches(updated_node, import_info.import_from):
-                aliases: Sequence[cst.ImportAlias] = updated_node.names  # type: ignore
-                # If multiple objects are imported in a single import statement,
+RULE_TO_PATHS = {
+    "langchain_to_community": ["langchain_to_community.json"],
+    "langchain_to_core": ["langchain_to_core.json"],
+    "community_to_core": ["community_to_core.json"],
+    "langchain_to_text_splitters": ["langchain_to_text_splitters.json"],
+    "community_to_partner": [
+        "anthropic.json",
+        "fireworks.json",
+        "ibm.json",
+        "openai.json",
+        "pinecone.json",
+    ],
+}
+
+
+def generate_import_replacer(rules: List[str]) -> Type[VisitorBasedCodemodCommand]:
+    """Generate a codemod to replace imports."""
+    paths = []
+    for rule in rules:
+        if rule not in RULE_TO_PATHS:
+            raise ValueError(f"Unknown rule: {rule}. Use one of {RULE_TO_PATHS.keys()}")
+        paths.extend(RULE_TO_PATHS[rule])
+
+    imports = _load_migrations(paths)
+
+    import_infos = [
+        ImportInfo(
+            import_from=get_import_from_from_str(import_str),
+            import_str=import_str,
+            to_import_str=to_import_str,
+        )
+        for import_str, to_import_str in imports.items()
+    ]
+    import_match = m.OneOf(*[info.import_from for info in import_infos])
+
+    class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
+        @m.leave(import_match)
+        def leave_replace_import(
+            self, _: cst.ImportFrom, updated_node: cst.ImportFrom
+        ) -> cst.ImportFrom:
+            for import_info in import_infos:
+                if m.matches(updated_node, import_info.import_from):
+                    aliases: Sequence[cst.ImportAlias] = updated_node.names  # type: ignore
+                    # If multiple objects are imported in a single import statement,
@@ -169,46 +181,12 @@ class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
-                    for alias in aliases
-                    if alias.name.value != import_info.to_import_str[-1]
-                ]
-                names[-1] = names[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
-                updated_node = updated_node.with_changes(names=names)
-            else:
-                return cst.RemoveFromParent()  # type: ignore[return-value]
-        return updated_node
-
-
-if __name__ == "__main__":
-    import textwrap
-
-    from rich.console import Console
-
-    console = Console()
-
-    source = textwrap.dedent(
-        """
-        from pydantic.settings import BaseSettings
-        from pydantic.color import Color
-        from pydantic.payment import PaymentCardNumber, PaymentCardBrand
-        from pydantic import Color
-        from pydantic import Color as Potato
-
-
-        class Potato(BaseSettings):
-            color: Color
-            payment: PaymentCardNumber
-            brand: PaymentCardBrand
-            potato: Potato
-        """
-    )
-    console.print(source)
-    console.print("=" * 80)
-    mod = cst.parse_module(source)
-    context = CodemodContext(filename="main.py")
-    wrapper = cst.MetadataWrapper(mod)
-    command = ReplaceImportsCodemod(context=context)
-
-    mod = wrapper.visit(command)
-    wrapper = cst.MetadataWrapper(mod)
-    command = AddImportsVisitor(context=context)  # type: ignore[assignment]
-    mod = wrapper.visit(command)
-    console.print(mod.code)
+                        for alias in aliases
+                        if alias.name.value != import_info.to_import_str[-1]
+                    ]
+                    names[-1] = names[-1].with_changes(
+                        comma=cst.MaybeSentinel.DEFAULT
+                    )
+                    updated_node = updated_node.with_changes(names=names)
+                else:
+                    return cst.RemoveFromParent()  # type: ignore[return-value]
+            return updated_node
+
+    return ReplaceImportsCodemod
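A sketch of driving the generated codemod directly with libcst (the expected output is illustrative; the actual mapping comes from the fixture files loaded for the rule):

import libcst as cst
from libcst.codemod import CodemodContext
from langchain_cli.namespaces.migrate.codemods.replace_imports import (
    generate_import_replacer,
)

# Build a replacer for a single rule and run it over a one-line module.
ReplaceImports = generate_import_replacer(["langchain_to_community"])
codemod = ReplaceImports(context=CodemodContext(filename="example.py"))
module = cst.parse_module("from langchain.chat_models import ChatOpenAI\n")
print(codemod.transform_module(module).code)
# Expected (illustrative): from langchain_community.chat_models import ChatOpenAI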


@@ -6,7 +6,7 @@ from typing import List, Tuple
 def generate_raw_migrations(
-    from_package: str, to_package: str
+    from_package: str, to_package: str, filter_by_all: bool = False
 ) -> List[Tuple[str, str]]:
     """Scan the `langchain` package and generate migrations for all modules."""
     package = importlib.import_module(from_package)
@@ -40,11 +40,13 @@ def generate_raw_migrations(
                     (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}")
                 )
 
-        # Iterate over all members of the module
-        for name, obj in inspect.getmembers(module):
-            # Check if it's a class or function
-            if inspect.isclass(obj) or inspect.isfunction(obj):
-                # Check if the module name of the obj starts with 'langchain_community'
-                if obj.__module__.startswith(to_package):
-                    items.append(
-                        (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}")
+        if not filter_by_all:
+            # Iterate over all members of the module
+            for name, obj in inspect.getmembers(module):
+                # Check if it's a class or function
+                if inspect.isclass(obj) or inspect.isfunction(obj):
+                    # Check if the module name of the obj starts with
+                    # 'langchain_community'
+                    if obj.__module__.startswith(to_package):
+                        items.append(
+                            (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}")
@@ -77,21 +79,12 @@ def generate_top_level_imports(pkg: str) -> List[Tuple[str, str]]:
     to importing it from the top level namespaces
     (e.g., langchain_community.chat_models.XYZ)
     """
-    import importlib
-
     package = importlib.import_module(pkg)
     items = []
 
-    # Only iterate through top-level modules/packages
-    for finder, modname, ispkg in pkgutil.iter_modules(
-        package.__path__, package.__name__ + "."
-    ):
-        if ispkg:
-            try:
-                module = importlib.import_module(modname)
-            except ModuleNotFoundError:
-                continue
-
+    # Function to handle importing from modules
+    def handle_module(module, module_name):
         if hasattr(module, "__all__"):
             all_objects = getattr(module, "__all__")
             for name in all_objects:
@@ -102,20 +95,36 @@ def generate_top_level_imports(pkg: str) -> List[Tuple[str, str]]:
                     original_module = obj.__module__
                     original_name = obj.__name__
                     # Form the new import path from the top-level namespace
-                    top_level_import = f"{modname}.{name}"
+                    top_level_import = f"{module_name}.{name}"
                     # Append the tuple with original and top-level paths
                     items.append(
                         (f"{original_module}.{original_name}", top_level_import)
                     )
 
+    # Handle the package itself (root level)
+    handle_module(package, pkg)
+
+    # Only iterate through top-level modules/packages
+    for finder, modname, ispkg in pkgutil.iter_modules(
+        package.__path__, package.__name__ + "."
+    ):
+        if ispkg:
+            try:
+                module = importlib.import_module(modname)
+                handle_module(module, modname)
+            except ModuleNotFoundError:
+                continue
+
     return items
 
 
 def generate_simplified_migrations(
-    from_package: str, to_package: str
+    from_package: str, to_package: str, filter_by_all: bool = True
 ) -> List[Tuple[str, str]]:
     """Get all the raw migrations, then simplify them if possible."""
-    raw_migrations = generate_raw_migrations(from_package, to_package)
+    raw_migrations = generate_raw_migrations(
+        from_package, to_package, filter_by_all=filter_by_all
+    )
     top_level_simplifications = generate_top_level_imports(to_package)
     top_level_dict = {full: top_level for full, top_level in top_level_simplifications}
     simple_migrations = []
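The new filter_by_all flag controls whether scanning is limited to names a module exports via __all__. A minimal sketch of the call (import path confirmed by the test file below):

from langchain_cli.namespaces.migrate.generate.generic import (
    generate_simplified_migrations,
)

# filter_by_all=True (the default) only follows names listed in __all__;
# False also inspects every class or function found on each scanned module.
migrations = generate_simplified_migrations(
    from_package="langchain",
    to_package="langchain_community",
    filter_by_all=True,
)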


@@ -8,7 +8,7 @@ import os
 import time
 import traceback
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
 
 import libcst as cst
 import rich
@@ -41,6 +41,9 @@ def main(
         default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
     ),
     log_file: Path = Option("log.txt", help="Log errors to this file."),
+    include_ipynb: bool = Option(
+        False, help="Include Jupyter Notebook files in the migration."
+    ),
 ):
     """Migrate langchain to the most recent version."""
     if not diff:
@@ -63,6 +66,8 @@
     else:
         package = path
     all_files = sorted(package.glob("**/*.py"))
+    if include_ipynb:
+        all_files.extend(sorted(package.glob("**/*.ipynb")))
 
     filtered_files = [
         file
@@ -86,11 +91,9 @@
     scratch: dict[str, Any] = {}
     start_time = time.time()
 
-    codemods = gather_codemods(disabled=disable)
-
     log_fp = log_file.open("a+", encoding="utf8")
     partial_run_codemods = functools.partial(
-        run_codemods, codemods, metadata_manager, scratch, package, diff
+        get_and_run_codemods, disable, metadata_manager, scratch, package, diff
     )
     with Progress(*Progress.get_default_columns(), transient=True) as progress:
         task = progress.add_task(description="Executing codemods...", total=len(files))
@@ -127,24 +130,28 @@
     raise Exit(1)
 
 
-def run_codemods(
-    codemods: List[Type[ContextAwareTransformer]],
+def get_and_run_codemods(
+    disabled_rules: List[Rule],
     metadata_manager: FullRepoManager,
     scratch: Dict[str, Any],
     package: Path,
     diff: bool,
     filename: str,
 ) -> Tuple[Union[str, None], Union[List[str], None]]:
-    try:
-        module_and_package = calculate_module_and_package(str(package), filename)
-        context = CodemodContext(
-            metadata_manager=metadata_manager,
-            filename=filename,
-            full_module_name=module_and_package.name,
-            full_package_name=module_and_package.package,
-        )
-        context.scratch.update(scratch)
+    """Run codemods from rules.
+
+    Wrapper around run_codemods to be used with multiprocessing.Pool.
+    """
+    codemods = gather_codemods(disabled=disabled_rules)
+    return run_codemods(codemods, metadata_manager, scratch, package, diff, filename)
+
+
+def _rewrite_file(
+    filename: str,
+    codemods: List[Type[ContextAwareTransformer]],
+    diff: bool,
+    context: CodemodContext,
+) -> Tuple[Union[str, None], Union[List[str], None]]:
     file_path = Path(filename)
     with file_path.open("r+", encoding="utf-8") as fp:
         code = fp.read()
@@ -171,6 +178,95 @@ def run_codemods(
             fp.write(output_code)
             fp.truncate()
     return None, None
+
+
+def _rewrite_notebook(
+    filename: str,
+    codemods: List[Type[ContextAwareTransformer]],
+    diff: bool,
+    context: CodemodContext,
+) -> Tuple[Optional[str], Optional[List[str]]]:
+    """Try to rewrite a Jupyter Notebook file."""
+    import nbformat
+
+    file_path = Path(filename)
+    if file_path.suffix != ".ipynb":
+        raise ValueError("Only Jupyter Notebook files (.ipynb) are supported.")
+
+    with file_path.open("r", encoding="utf-8") as fp:
+        notebook = nbformat.read(fp, as_version=4)
+
+    diffs = []
+
+    for cell in notebook.cells:
+        if cell.cell_type == "code":
+            code = "".join(cell.source)
+
+            # Skip code if any of the lines begin with a magic command or
+            # a ! command.
+            # We can try to handle later.
+            if any(
+                line.startswith("!") or line.startswith("%")
+                for line in code.splitlines()
+            ):
+                continue
+
+            input_tree = cst.parse_module(code)
+
+            # TODO(Team): Quick hack, need to figure out
+            # how to handle this correctly.
+            # This prevents the code from trying to re-insert the imports
+            # for every cell in the notebook.
+            local_context = CodemodContext()
+
+            for codemod in codemods:
+                transformer = codemod(context=local_context)
+                output_tree = transformer.transform_module(input_tree)
+                input_tree = output_tree
+
+            output_code = input_tree.code
+            if code != output_code:
+                cell.source = output_code.splitlines(keepends=True)
+                if diff:
+                    cell_diff = difflib.unified_diff(
+                        code.splitlines(keepends=True),
+                        output_code.splitlines(keepends=True),
+                        fromfile=filename,
+                        tofile=filename,
+                    )
+                    diffs.extend(list(cell_diff))
+
+    if diff:
+        return None, diffs
+
+    with file_path.open("w", encoding="utf-8") as fp:
+        nbformat.write(notebook, fp)
+
+    return None, None
+
+
+def run_codemods(
+    codemods: List[Type[ContextAwareTransformer]],
+    metadata_manager: FullRepoManager,
+    scratch: Dict[str, Any],
+    package: Path,
+    diff: bool,
+    filename: str,
+) -> Tuple[Union[str, None], Union[List[str], None]]:
+    try:
+        module_and_package = calculate_module_and_package(str(package), filename)
+        context = CodemodContext(
+            metadata_manager=metadata_manager,
+            filename=filename,
+            full_module_name=module_and_package.name,
+            full_package_name=module_and_package.package,
+        )
+        context.scratch.update(scratch)
+
+        if filename.endswith(".ipynb"):
+            return _rewrite_notebook(filename, codemods, diff, context)
+        else:
+            return _rewrite_file(filename, codemods, diff, context)
     except cst.ParserSyntaxError as exc:
         return (
             f"A syntax error happened on {filename}. This file cannot be "


@@ -32,10 +32,15 @@ def cli():
     default=None,
     help="Output file for the migration script.",
 )
-def generic(pkg1: str, pkg2: str, output: str) -> None:
+@click.option(
+    "--filter-by-all/--no-filter-by-all",
+    default=True,
+    help="Output file for the migration script.",
+)
+def generic(pkg1: str, pkg2: str, output: str, filter_by_all: bool) -> None:
     """Generate a migration script."""
     click.echo("Migration script generated.")
-    migrations = generate_simplified_migrations(pkg1, pkg2)
+    migrations = generate_simplified_migrations(pkg1, pkg2, filter_by_all=filter_by_all)
 
     if output is None:
         output = f"{pkg1}_to_{pkg2}.json"
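A hypothetical programmatic invocation via click's test runner; the import path for `generic` and the positional pkg1/pkg2 arguments are assumptions, not shown in the diff:

from click.testing import CliRunner

# Hypothetical import path for the command defined above.
from langchain_cli.namespaces.migrate.generate.cli import generic

runner = CliRunner()
result = runner.invoke(
    generic, ["langchain", "langchain_community", "--no-filter-by-all"]
)
print(result.output)  # expected to include "Migration script generated."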


@@ -1,16 +1,22 @@
+from langchain._api import suppress_langchain_deprecation_warning as sup2
+from langchain_core._api import suppress_langchain_deprecation_warning as sup1
+
 from langchain_cli.namespaces.migrate.generate.generic import (
     generate_simplified_migrations,
-    generate_raw_migrations,
 )
 
 
 def test_create_json_agent_migration() -> None:
     """Test the migration of create_json_agent from langchain to langchain_community."""
-    raw_migrations = generate_simplified_migrations(
-        from_package="langchain", to_package="langchain_community"
-    )
+    with sup1():
+        with sup2():
+            raw_migrations = generate_simplified_migrations(
+                from_package="langchain", to_package="langchain_community"
+            )
     json_agent_migrations = [
-        migration for migration in raw_migrations if "create_json_agent" in migration[0]
+        migration
+        for migration in raw_migrations
+        if "create_json_agent" in migration[0]
     ]
     assert json_agent_migrations == [
         (
@@ -30,12 +36,16 @@ def test_create_json_agent_migration() -> None:
 
 def test_create_single_store_retriever_db() -> None:
     """Test migration from langchain to langchain_core"""
-    raw_migrations = generate_simplified_migrations(
-        from_package="langchain", to_package="langchain_core"
-    )
+    with sup1():
+        with sup2():
+            raw_migrations = generate_simplified_migrations(
+                from_package="langchain", to_package="langchain_core"
+            )
     # SingleStore was an old name for VectorStoreRetriever
     single_store_migration = [
-        migration for migration in raw_migrations if "SingleStore" in migration[0]
+        migration
+        for migration in raw_migrations
+        if "SingleStore" in migration[0]
     ]
     assert single_store_migration == [
         (

View File

@@ -7,9 +7,18 @@ pytest.importorskip("libcst")
 from libcst.codemod import CodemodTest
 
 from langchain_cli.namespaces.migrate.codemods.replace_imports import (
-    ReplaceImportsCodemod,
+    generate_import_replacer,
 )
 
+ReplaceImportsCodemod = generate_import_replacer(
+    [
+        "langchain_to_community",
+        "community_to_partner",
+        "langchain_to_core",
+        "community_to_core",
+    ]
+)  # type: ignore[attr-defined] # noqa: E501
+
 
 class TestReplaceImportsCommand(CodemodTest):
     TRANSFORM = ReplaceImportsCodemod
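CodemodTest drives the class assigned to TRANSFORM through its assertCodemod helper. A sketch of the kind of test this setup enables (the before/after pair is illustrative, not copied from the suite):

class TestReplaceImportsCommand(CodemodTest):
    TRANSFORM = ReplaceImportsCodemod

    def test_moved_import(self) -> None:
        before = "from langchain.chat_models import ChatOpenAI"
        after = "from langchain_community.chat_models import ChatOpenAI"
        # assertCodemod runs TRANSFORM on `before` and compares against `after`.
        self.assertCodemod(before, after)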