From d82cbf5e764c6e0a9d0bee2558119d302b0b57ad Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Mon, 20 Nov 2023 16:09:30 -0500
Subject: [PATCH] Separate out langchain_core package (#13577)

Co-authored-by: Nuno Campos
Co-authored-by: Bagatur
Co-authored-by: Erick Friis
---
 .../workflows/_compile_integration_test.yml | 12 +
 .github/workflows/_lint.yml | 14 +-
 .github/workflows/_pydantic_compatibility.yml | 24 +
 .github/workflows/_test.yml | 27 +-
 .github/workflows/langchain_ci.yml | 9 +
 .github/workflows/langchain_core_ci.yml | 52 +
 .github/workflows/langchain_core_release.yml | 13 +
 .../workflows/langchain_experimental_ci.yml | 4 +
 docs/api_reference/create_api_rst.py | 17 +-
 .../themes/scikit-learn-modern/nav.html | 3 +
 libs/core/Makefile | 54 +
 libs/core/README.md | 1 +
 libs/core/langchain_core/__init__.py | 7 +
 libs/core/langchain_core/_api/__init__.py | 26 +
 libs/core/langchain_core/_api/deprecation.py | 341 ++
 libs/core/langchain_core/_api/path.py | 36 +
 .../langchain_core/callbacks}/__init__.py | 0
 libs/core/langchain_core/callbacks/base.py | 598 ++++
 libs/core/langchain_core/callbacks/manager.py | 2075 +++++++++++
 libs/core/langchain_core/callbacks/stdout.py | 97 +
 .../callbacks/streaming_stdout.py | 67 +
 .../callbacks/tracers}/__init__.py | 0
 .../langchain_core/callbacks/tracers/base.py | 537 +++
 .../callbacks/tracers/evaluation.py | 223 ++
 .../callbacks/tracers/langchain.py | 262 ++
 .../callbacks/tracers/langchain_v1.py | 185 +
 .../callbacks/tracers/log_stream.py | 313 ++
 .../callbacks/tracers/root_listeners.py | 54 +
 .../callbacks/tracers/run_collector.py | 52 +
 .../callbacks/tracers/schemas.py | 140 +
 .../callbacks/tracers/stdout.py | 178 +
 libs/core/langchain_core/chat_model.py | 735 ++++
 libs/core/langchain_core/env.py | 17 +
 libs/core/langchain_core/globals/__init__.py | 197 ++
 libs/core/langchain_core/llm.py | 1077 ++++++
 libs/core/langchain_core/load/__init__.py | 6 +
 libs/core/langchain_core/load/dump.py | 26 +
 libs/core/langchain_core/load/load.py | 130 +
 libs/core/langchain_core/load/serializable.py | 207 ++
 .../langchain_core/output_parsers/__init__.py | 0
 .../langchain_core/output_parsers/list.py | 79 +
 libs/core/langchain_core/prompts/__init__.py | 75 +
 libs/core/langchain_core/prompts/base.py | 173 +
 libs/core/langchain_core/prompts/chat.py | 748 ++++
 .../prompts/example_selector/__init__.py | 14 +
 .../prompts/example_selector/base.py | 15 +
 .../prompts/example_selector/length_based.py | 63 +
 .../example_selector/semantic_similarity.py | 165 +
 libs/core/langchain_core/prompts/few_shot.py | 343 ++
 .../prompts/few_shot_with_templates.py | 153 +
 libs/core/langchain_core/prompts/loading.py | 162 +
 libs/core/langchain_core/prompts/pipeline.py | 56 +
 libs/core/langchain_core/prompts/prompt.py | 250 ++
 .../langchain_core/pydantic_v1/__init__.py | 23 +
 .../langchain_core/pydantic_v1/dataclasses.py | 4 +
 libs/core/langchain_core/pydantic_v1/main.py | 4 +
 .../core/langchain_core/runnables/__init__.py | 57 +
 libs/core/langchain_core/runnables/base.py | 3026 ++++++++++++++++
 libs/core/langchain_core/runnables/branch.py | 254 ++
 libs/core/langchain_core/runnables/config.py | 401 +++
 .../langchain_core/runnables/configurable.py | 388 +++
 .../langchain_core/runnables/fallbacks.py | 344 ++
 libs/core/langchain_core/runnables/history.py | 288 ++
 .../langchain_core/runnables/passthrough.py | 453 +++
 libs/core/langchain_core/runnables/retry.py | 337 ++
 libs/core/langchain_core/runnables/router.py | 206 ++
 libs/core/langchain_core/runnables/utils.py | 327 ++
libs/core/langchain_core/schema/__init__.py | 78 + libs/core/langchain_core/schema/agent.py | 74 + libs/core/langchain_core/schema/cache.py | 24 + libs/core/langchain_core/schema/chat.py | 13 + .../langchain_core/schema/chat_history.py | 67 + libs/core/langchain_core/schema/document.py | 91 + libs/core/langchain_core/schema/embeddings.py | 27 + libs/core/langchain_core/schema/exceptions.py | 2 + .../langchain_core/schema/language_model.py | 291 ++ libs/core/langchain_core/schema/memory.py | 59 + libs/core/langchain_core/schema/messages.py | 415 +++ libs/core/langchain_core/schema/output.py | 175 + .../langchain_core/schema/output_parser.py | 475 +++ libs/core/langchain_core/schema/prompt.py | 28 + .../langchain_core/schema/prompt_template.py | 228 ++ libs/core/langchain_core/schema/retriever.py | 275 ++ libs/core/langchain_core/schema/storage.py | 53 + .../core/langchain_core/schema/vectorstore.py | 702 ++++ libs/core/langchain_core/tool.py | 845 +++++ libs/core/langchain_core/utils/__init__.py | 38 + libs/core/langchain_core/utils/aiter.py | 209 ++ libs/core/langchain_core/utils/formatting.py | 38 + libs/core/langchain_core/utils/input.py | 42 + libs/core/langchain_core/utils/iter.py | 175 + libs/core/langchain_core/utils/loading.py | 54 + libs/core/langchain_core/utils/pydantic.py | 14 + libs/core/langchain_core/utils/utils.py | 180 + libs/core/poetry.lock | 2689 +++++++++++++++ libs/core/pyproject.toml | 85 + libs/core/tests/__init__.py | 0 libs/core/tests/unit_tests/__init__.py | 0 libs/core/tests/unit_tests/_api/__init__.py | 0 .../tests/unit_tests/_api/test_deprecation.py | 4 +- .../tests/unit_tests/_api/test_imports.py | 2 +- .../tests/unit_tests/_api/test_path.py | 4 +- .../tests/unit_tests/data/prompt_file.txt | 2 + .../data/prompts/prompt_extra_args.json | 5 + .../data/prompts/prompt_missing_args.json | 3 + .../data/prompts/simple_prompt.json | 4 + .../unit_tests/examples/example-non-utf8.csv | 11 + .../unit_tests/examples/example-non-utf8.txt | 1 + .../unit_tests/examples/example-utf8.csv | 11 + .../unit_tests/examples/example-utf8.txt | 6 + .../unit_tests/examples/example_prompt.json | 0 .../tests/unit_tests/examples/examples.json | 0 .../tests/unit_tests/examples/examples.yaml | 0 .../unit_tests/examples/few_shot_prompt.json | 0 .../unit_tests/examples/few_shot_prompt.yaml | 0 .../few_shot_prompt_example_prompt.json | 0 .../examples/few_shot_prompt_examples_in.json | 0 .../few_shot_prompt_yaml_examples.yaml | 0 .../examples/jinja_injection_prompt.json | 0 .../examples/jinja_injection_prompt.yaml | 0 .../examples/prompt_with_output_parser.json | 0 .../unit_tests/examples/simple_prompt.json | 0 .../unit_tests/examples/simple_prompt.yaml | 0 .../simple_prompt_with_template_file.json | 0 .../unit_tests/examples/simple_template.txt | 0 libs/core/tests/unit_tests/fake/__init__.py | 0 libs/core/tests/unit_tests/fake/callbacks.py | 391 +++ libs/core/tests/unit_tests/fake/chat_model.py | 105 + libs/core/tests/unit_tests/fake/llm.py | 90 + libs/core/tests/unit_tests/fake/memory.py | 23 + libs/core/tests/unit_tests/prompt_file.txt | 2 + .../tests/unit_tests/prompts/__init__.py | 0 .../unit_tests/prompts/prompt_extra_args.json | 5 + .../prompts/prompt_missing_args.json | 3 + .../unit_tests/prompts/simple_prompt.json | 4 + .../tests/unit_tests/prompts/test_chat.py | 6 +- .../tests/unit_tests/prompts/test_few_shot.py | 12 +- .../prompts/test_few_shot_with_templates.py | 4 +- .../tests/unit_tests/prompts/test_imports.py | 3 +- .../test_length_based_example_selector.py | 6 +- 
.../tests/unit_tests/prompts/test_loading.py | 22 +- .../prompts/test_pipeline_prompt.py | 6 +- .../tests/unit_tests/prompts/test_prompt.py | 2 +- .../tests/unit_tests/prompts/test_utils.py | 2 +- .../tests/unit_tests/runnable/__init__.py | 0 .../runnable/__snapshots__/test_runnable.ambr | 556 ++- .../tests/unit_tests}/runnable/test_config.py | 10 +- .../unit_tests}/runnable/test_history.py | 10 +- .../unit_tests}/runnable/test_runnable.py | 506 +-- .../tests/unit_tests}/runnable/test_utils.py | 2 +- libs/core/tests/unit_tests/schema/__init__.py | 0 .../tests/unit_tests/schema/test_imports.py | 43 + .../tests/unit_tests/schema/test_messages.py | 102 + .../tests/unit_tests/schema/test_output.py | 60 + libs/core/tests/unit_tests/test_globals.py | 31 + .../tests/unit_tests/test_tool.py} | 16 +- libs/core/tests/unit_tests/utils/__init__.py | 0 .../tests/unit_tests/utils/test_imports.py | 21 + libs/core/tests/unit_tests/utils/test_iter.py | 21 + libs/langchain/Makefile | 5 +- libs/langchain/langchain/__init__.py | 10 +- libs/langchain/langchain/_api/deprecation.py | 356 +- libs/langchain/langchain/_api/path.py | 37 +- libs/langchain/langchain/adapters/openai.py | 7 +- libs/langchain/langchain/agents/__init__.py | 3 +- libs/langchain/langchain/agents/agent.py | 28 +- .../langchain/agents/agent_iterator.py | 7 +- .../agents/agent_toolkits/__init__.py | 3 +- .../agent_toolkits/ainetwork/toolkit.py | 3 +- .../agents/agent_toolkits/amadeus/toolkit.py | 3 +- .../langchain/agents/agent_toolkits/base.py | 3 +- .../openai_functions.py | 9 +- .../agents/agent_toolkits/csv/__init__.py | 2 +- .../agent_toolkits/file_management/toolkit.py | 3 +- .../agents/agent_toolkits/gmail/toolkit.py | 3 +- .../agents/agent_toolkits/json/base.py | 3 +- .../agents/agent_toolkits/nla/tool.py | 3 +- .../agents/agent_toolkits/nla/toolkit.py | 5 +- .../agent_toolkits/office365/toolkit.py | 3 +- .../agents/agent_toolkits/openapi/base.py | 3 +- .../agents/agent_toolkits/openapi/planner.py | 8 +- .../agent_toolkits/openapi/planner_prompt.py | 2 +- .../agents/agent_toolkits/openapi/toolkit.py | 3 +- .../agents/agent_toolkits/pandas/__init__.py | 2 +- .../agent_toolkits/playwright/toolkit.py | 3 +- .../agents/agent_toolkits/powerbi/base.py | 3 +- .../agents/agent_toolkits/powerbi/toolkit.py | 17 +- .../agents/agent_toolkits/python/__init__.py | 2 +- .../agents/agent_toolkits/spark/__init__.py | 2 +- .../agents/agent_toolkits/spark_sql/base.py | 3 +- .../agent_toolkits/spark_sql/toolkit.py | 5 +- .../agents/agent_toolkits/sql/base.py | 15 +- .../agents/agent_toolkits/sql/toolkit.py | 5 +- .../agents/agent_toolkits/vectorstore/base.py | 3 +- .../agent_toolkits/vectorstore/toolkit.py | 7 +- .../agents/agent_toolkits/xorbits/__init__.py | 2 +- .../agents/agent_toolkits/zapier/toolkit.py | 3 +- libs/langchain/langchain/agents/chat/base.py | 17 +- .../langchain/agents/chat/output_parser.py | 3 +- .../langchain/agents/conversational/base.py | 7 +- .../agents/conversational/output_parser.py | 3 +- .../agents/conversational_chat/base.py | 21 +- .../conversational_chat/output_parser.py | 3 +- .../langchain/agents/format_scratchpad/log.py | 2 +- .../format_scratchpad/log_to_messages.py | 4 +- .../format_scratchpad/openai_functions.py | 4 +- .../agents/format_scratchpad/openai_tools.py | 7 +- .../langchain/agents/format_scratchpad/xml.py | 2 +- libs/langchain/langchain/agents/initialize.py | 3 +- libs/langchain/langchain/agents/load_tools.py | 2 +- libs/langchain/langchain/agents/loading.py | 2 +- libs/langchain/langchain/agents/mrkl/base.py 
| 7 +- .../langchain/agents/mrkl/output_parser.py | 3 +- .../langchain/agents/openai_assistant/base.py | 11 +- .../agent_token_buffer_memory.py | 5 +- .../agents/openai_functions_agent/base.py | 35 +- .../openai_functions_multi_agent/base.py | 41 +- .../langchain/agents/output_parsers/json.py | 3 +- .../agents/output_parsers/openai_functions.py | 11 +- .../agents/output_parsers/openai_tools.py | 11 +- .../output_parsers/react_json_single_input.py | 3 +- .../output_parsers/react_single_input.py | 3 +- .../agents/output_parsers/self_ask.py | 3 +- .../langchain/agents/output_parsers/xml.py | 3 +- libs/langchain/langchain/agents/react/base.py | 7 +- .../langchain/agents/react/output_parser.py | 3 +- .../agents/react/textworld_prompt.py | 2 +- .../langchain/agents/react/wiki_prompt.py | 2 +- libs/langchain/langchain/agents/schema.py | 4 +- .../agents/self_ask_with_search/base.py | 7 +- .../agents/self_ask_with_search/prompt.py | 2 +- .../langchain/agents/structured_chat/base.py | 17 +- .../agents/structured_chat/output_parser.py | 9 +- libs/langchain/langchain/agents/xml/base.py | 5 +- libs/langchain/langchain/base_language.py | 2 +- libs/langchain/langchain/cache.py | 11 +- .../langchain/callbacks/aim_callback.py | 3 +- .../langchain/callbacks/argilla_callback.py | 2 +- .../langchain/callbacks/arize_callback.py | 3 +- .../langchain/callbacks/arthur_callback.py | 2 +- libs/langchain/langchain/callbacks/base.py | 2 +- .../langchain/callbacks/clearml_callback.py | 3 +- .../langchain/callbacks/comet_ml_callback.py | 3 +- .../langchain/callbacks/confident_callback.py | 2 +- .../langchain/callbacks/context_callback.py | 5 +- libs/langchain/langchain/callbacks/file.py | 5 +- .../langchain/callbacks/flyte_callback.py | 3 +- .../langchain/callbacks/infino_callback.py | 5 +- .../callbacks/labelstudio_callback.py | 5 +- .../langchain/callbacks/llmonitor_callback.py | 6 +- libs/langchain/langchain/callbacks/manager.py | 7 +- .../langchain/callbacks/mlflow_callback.py | 3 +- .../langchain/callbacks/openai_info.py | 3 +- .../callbacks/promptlayer_callback.py | 7 +- .../langchain/callbacks/sagemaker_callback.py | 3 +- libs/langchain/langchain/callbacks/stdout.py | 2 +- .../langchain/callbacks/streaming_aiter.py | 3 +- .../callbacks/streaming_aiter_final_only.py | 3 +- .../langchain/callbacks/streaming_stdout.py | 2 +- .../streamlit/streamlit_callback_handler.py | 3 +- .../langchain/callbacks/tracers/__init__.py | 9 +- .../langchain/callbacks/tracers/base.py | 2 +- .../langchain/callbacks/tracers/evaluation.py | 2 +- .../langchain/callbacks/tracers/langchain.py | 2 +- .../callbacks/tracers/langchain_v1.py | 2 +- .../langchain/callbacks/tracers/log_stream.py | 2 +- .../callbacks/tracers/root_listeners.py | 2 +- .../callbacks/tracers/run_collector.py | 2 +- .../langchain/callbacks/tracers/schemas.py | 2 +- .../langchain/callbacks/tracers/stdout.py | 2 +- .../langchain/callbacks/trubrics_callback.py | 7 +- .../langchain/callbacks/wandb_callback.py | 3 +- libs/langchain/langchain/chains/api/base.py | 7 +- .../langchain/chains/api/openapi/chain.py | 4 +- .../chains/api/openapi/requests_chain.py | 7 +- .../chains/api/openapi/response_chain.py | 7 +- libs/langchain/langchain/chains/api/prompt.py | 2 +- libs/langchain/langchain/chains/base.py | 20 +- .../chains/chat_vector_db/prompts.py | 2 +- .../chains/combine_documents/base.py | 5 +- .../chains/combine_documents/map_reduce.py | 7 +- .../chains/combine_documents/map_rerank.py | 7 +- .../chains/combine_documents/reduce.py | 5 +- 
.../chains/combine_documents/refine.py | 9 +- .../chains/combine_documents/stuff.py | 9 +- .../chains/constitutional_ai/base.py | 5 +- .../chains/constitutional_ai/models.py | 2 +- .../chains/constitutional_ai/prompts.py | 4 +- .../langchain/chains/conversation/base.py | 5 +- .../langchain/chains/conversation/prompt.py | 2 +- .../chains/conversational_retrieval/base.py | 15 +- .../conversational_retrieval/prompts.py | 2 +- .../chains/elasticsearch_database/base.py | 7 +- .../chains/elasticsearch_database/prompts.py | 2 +- .../langchain/chains/example_generator.py | 7 +- libs/langchain/langchain/chains/flare/base.py | 6 +- .../langchain/chains/flare/prompts.py | 4 +- .../langchain/chains/graph_qa/arangodb.py | 5 +- .../langchain/chains/graph_qa/base.py | 7 +- .../langchain/chains/graph_qa/cypher.py | 7 +- .../langchain/chains/graph_qa/falkordb.py | 5 +- .../langchain/chains/graph_qa/hugegraph.py | 7 +- .../langchain/chains/graph_qa/kuzu.py | 7 +- .../langchain/chains/graph_qa/nebulagraph.py | 7 +- .../chains/graph_qa/neptune_cypher.py | 5 +- .../langchain/chains/graph_qa/prompts.py | 2 +- .../langchain/chains/graph_qa/sparql.py | 7 +- libs/langchain/langchain/chains/hyde/base.py | 6 +- .../langchain/chains/hyde/prompts.py | 2 +- libs/langchain/langchain/chains/llm.py | 41 +- .../langchain/chains/llm_checker/base.py | 7 +- .../langchain/chains/llm_checker/prompt.py | 2 +- .../langchain/chains/llm_math/base.py | 7 +- .../langchain/chains/llm_math/prompt.py | 2 +- .../langchain/chains/llm_requests.py | 3 +- .../chains/llm_summarization_checker/base.py | 7 +- libs/langchain/langchain/chains/loading.py | 10 +- libs/langchain/langchain/chains/mapreduce.py | 7 +- libs/langchain/langchain/chains/moderation.py | 3 +- .../langchain/langchain/chains/natbot/base.py | 5 +- .../langchain/chains/natbot/prompt.py | 2 +- .../langchain/chains/openai_functions/base.py | 30 +- .../openai_functions/citation_fuzzy_match.py | 9 +- .../chains/openai_functions/extraction.py | 9 +- .../chains/openai_functions/openapi.py | 8 +- .../openai_functions/qa_with_structure.py | 13 +- .../chains/openai_functions/tagging.py | 5 +- .../chains/openai_tools/extraction.py | 9 +- .../langchain/chains/prompt_selector.py | 7 +- .../langchain/chains/qa_generation/base.py | 7 +- .../langchain/chains/qa_generation/prompt.py | 4 +- .../langchain/chains/qa_with_sources/base.py | 7 +- .../chains/qa_with_sources/loading.py | 5 +- .../qa_with_sources/map_reduce_prompt.py | 2 +- .../chains/qa_with_sources/refine_prompts.py | 2 +- .../chains/qa_with_sources/retrieval.py | 5 +- .../chains/qa_with_sources/stuff_prompt.py | 2 +- .../chains/qa_with_sources/vector_db.py | 5 +- .../chains/query_constructor/base.py | 13 +- .../langchain/chains/query_constructor/ir.py | 2 +- .../chains/query_constructor/parser.py | 3 +- .../chains/query_constructor/prompt.py | 2 +- .../chains/query_constructor/schema.py | 2 +- .../chains/question_answering/__init__.py | 5 +- .../question_answering/map_reduce_prompt.py | 4 +- .../question_answering/map_rerank_prompt.py | 2 +- .../question_answering/refine_prompts.py | 4 +- .../chains/question_answering/stuff_prompt.py | 4 +- .../langchain/chains/retrieval_qa/base.py | 13 +- .../langchain/chains/retrieval_qa/prompt.py | 2 +- .../langchain/langchain/chains/router/base.py | 3 +- .../chains/router/embedding_router.py | 7 +- .../langchain/chains/router/llm_router.py | 11 +- .../langchain/chains/router/multi_prompt.py | 5 +- .../chains/router/multi_retrieval_qa.py | 7 +- libs/langchain/langchain/chains/sequential.py | 5 +- 
.../langchain/chains/sql_database/prompt.py | 2 +- .../langchain/chains/sql_database/query.py | 9 +- .../langchain/chains/summarize/__init__.py | 5 +- .../chains/summarize/map_reduce_prompt.py | 2 +- .../chains/summarize/refine_prompts.py | 2 +- .../chains/summarize/stuff_prompt.py | 2 +- libs/langchain/langchain/chains/transform.py | 3 +- libs/langchain/langchain/chat_loaders/base.py | 2 +- .../chat_loaders/facebook_messenger.py | 5 +- .../langchain/langchain/chat_loaders/gmail.py | 5 +- .../langchain/chat_loaders/imessage.py | 5 +- .../langchain/chat_loaders/langsmith.py | 5 +- .../langchain/langchain/chat_loaders/slack.py | 5 +- .../langchain/chat_loaders/telegram.py | 5 +- .../langchain/langchain/chat_loaders/utils.py | 4 +- .../langchain/chat_loaders/whatsapp.py | 5 +- .../langchain/chat_models/anthropic.py | 21 +- .../langchain/chat_models/anyscale.py | 7 +- .../langchain/chat_models/azure_openai.py | 5 +- .../langchain/chat_models/azureml_endpoint.py | 14 +- .../langchain/chat_models/baichuan.py | 18 +- .../chat_models/baidu_qianfan_endpoint.py | 19 +- libs/langchain/langchain/chat_models/base.py | 744 +--- .../langchain/chat_models/bedrock.py | 7 +- .../langchain/langchain/chat_models/cohere.py | 21 +- libs/langchain/langchain/chat_models/ernie.py | 10 +- .../langchain/chat_models/everlyai.py | 5 +- libs/langchain/langchain/chat_models/fake.py | 7 +- .../langchain/chat_models/fireworks.py | 23 +- .../langchain/chat_models/gigachat.py | 21 +- .../langchain/chat_models/google_palm.py | 24 +- libs/langchain/langchain/chat_models/human.py | 16 +- .../langchain/chat_models/hunyuan.py | 18 +- .../chat_models/javelin_ai_gateway.py | 17 +- .../langchain/chat_models/jinachat.py | 41 +- libs/langchain/langchain/chat_models/konko.py | 8 +- .../langchain/chat_models/litellm.py | 29 +- .../langchain/chat_models/minimax.py | 13 +- .../chat_models/mlflow_ai_gateway.py | 17 +- .../langchain/langchain/chat_models/ollama.py | 17 +- .../langchain/langchain/chat_models/openai.py | 39 +- .../langchain/chat_models/pai_eas_endpoint.py | 22 +- .../chat_models/promptlayer_openai.py | 5 +- .../langchain/langchain/chat_models/tongyi.py | 34 +- .../langchain/chat_models/vertexai.py | 21 +- .../langchain/langchain/chat_models/yandex.py | 17 +- .../langchain/docstore/arbitrary_fn.py | 3 +- libs/langchain/langchain/docstore/document.py | 2 +- .../langchain/document_loaders/airbyte.py | 3 +- .../document_loaders/apify_dataset.py | 5 +- .../langchain/document_loaders/base.py | 3 +- .../langchain/document_loaders/base_o365.py | 9 +- .../document_loaders/blob_loaders/schema.py | 2 +- .../langchain/document_loaders/concurrent.py | 3 +- .../langchain/document_loaders/docugami.py | 2 +- .../langchain/document_loaders/dropbox.py | 3 +- .../langchain/document_loaders/embaas.py | 2 +- .../langchain/document_loaders/generic.py | 3 +- .../langchain/document_loaders/github.py | 2 +- .../langchain/document_loaders/googledrive.py | 3 +- .../langchain/document_loaders/joplin.py | 3 +- .../langchain/document_loaders/lakefs.py | 2 +- .../langchain/document_loaders/onedrive.py | 3 +- .../document_loaders/onedrive_file.py | 3 +- .../document_loaders/parsers/audio.py | 3 +- .../document_loaders/parsers/docai.py | 3 +- .../document_loaders/parsers/generic.py | 3 +- .../document_loaders/parsers/msword.py | 3 +- .../langchain/document_loaders/parsers/pdf.py | 2 +- .../langchain/document_loaders/parsers/txt.py | 3 +- .../langchain/document_loaders/rocksetdb.py | 3 +- .../langchain/document_loaders/sharepoint.py | 3 +- 
.../langchain/document_loaders/sitemap.py | 3 +- .../document_loaders/tensorflow_datasets.py | 3 +- .../langchain/document_loaders/youtube.py | 5 +- .../beautiful_soup_transformer.py | 2 +- .../doctran_text_extract.py | 3 +- .../document_transformers/doctran_text_qa.py | 3 +- .../doctran_text_translate.py | 3 +- .../embeddings_redundant_filter.py | 6 +- .../document_transformers/google_translate.py | 3 +- .../document_transformers/html2text.py | 2 +- .../long_context_reorder.py | 4 +- .../nuclia_text_transform.py | 3 +- .../document_transformers/openai_functions.py | 13 +- .../langchain/embeddings/aleph_alpha.py | 5 +- libs/langchain/langchain/embeddings/awa.py | 4 +- .../langchain/embeddings/azure_openai.py | 3 +- .../embeddings/baidu_qianfan_endpoint.py | 5 +- libs/langchain/langchain/embeddings/base.py | 2 +- .../langchain/langchain/embeddings/bedrock.py | 4 +- libs/langchain/langchain/embeddings/cache.py | 5 +- .../langchain/embeddings/clarifai.py | 5 +- libs/langchain/langchain/embeddings/cohere.py | 5 +- .../langchain/embeddings/dashscope.py | 4 +- .../langchain/embeddings/deepinfra.py | 4 +- libs/langchain/langchain/embeddings/edenai.py | 5 +- .../langchain/embeddings/elasticsearch.py | 2 +- libs/langchain/langchain/embeddings/embaas.py | 4 +- libs/langchain/langchain/embeddings/ernie.py | 4 +- libs/langchain/langchain/embeddings/fake.py | 5 +- .../langchain/embeddings/fastembed.py | 5 +- .../langchain/embeddings/google_palm.py | 4 +- .../langchain/langchain/embeddings/gpt4all.py | 4 +- .../langchain/embeddings/gradient_ai.py | 4 +- .../langchain/embeddings/huggingface.py | 5 +- .../langchain/embeddings/huggingface_hub.py | 5 +- .../embeddings/javelin_ai_gateway.py | 4 +- libs/langchain/langchain/embeddings/jina.py | 4 +- .../langchain/embeddings/johnsnowlabs.py | 3 +- .../langchain/embeddings/llamacpp.py | 4 +- .../langchain/embeddings/llm_rails.py | 5 +- .../langchain/langchain/embeddings/localai.py | 7 +- .../langchain/langchain/embeddings/minimax.py | 4 +- .../langchain/embeddings/mlflow_gateway.py | 4 +- .../langchain/embeddings/modelscope_hub.py | 4 +- .../langchain/embeddings/mosaicml.py | 4 +- .../langchain/embeddings/nlpcloud.py | 5 +- .../langchain/embeddings/octoai_embeddings.py | 5 +- libs/langchain/langchain/embeddings/ollama.py | 5 +- libs/langchain/langchain/embeddings/openai.py | 7 +- .../embeddings/sagemaker_endpoint.py | 5 +- .../langchain/embeddings/self_hosted.py | 5 +- .../langchain/embeddings/spacy_embeddings.py | 4 +- .../langchain/embeddings/tensorflow_hub.py | 4 +- .../langchain/embeddings/vertexai.py | 5 +- .../langchain/embeddings/voyageai.py | 7 +- .../langchain/embeddings/xinference.py | 2 +- .../agents/trajectory_eval_chain.py | 7 +- .../agents/trajectory_eval_prompt.py | 4 +- .../evaluation/comparison/eval_chain.py | 9 +- .../langchain/evaluation/comparison/prompt.py | 2 +- .../evaluation/criteria/eval_chain.py | 7 +- .../langchain/evaluation/criteria/prompt.py | 2 +- .../evaluation/embedding_distance/base.py | 6 +- .../langchain/langchain/evaluation/loading.py | 3 +- .../langchain/evaluation/qa/eval_chain.py | 9 +- .../langchain/evaluation/qa/eval_prompt.py | 2 +- .../langchain/evaluation/qa/generate_chain.py | 7 +- .../evaluation/qa/generate_prompt.py | 2 +- libs/langchain/langchain/evaluation/schema.py | 5 +- .../evaluation/scoring/eval_chain.py | 9 +- .../langchain/evaluation/scoring/prompt.py | 2 +- .../evaluation/string_distance/base.py | 5 +- libs/langchain/langchain/formatting.py | 2 +- libs/langchain/langchain/globals/__init__.py | 2 +- 
.../langchain/graphs/graph_document.py | 6 +- libs/langchain/langchain/hub.py | 4 +- libs/langchain/langchain/indexes/_api.py | 7 +- libs/langchain/langchain/indexes/graph.py | 7 +- .../indexes/prompts/entity_extraction.py | 2 +- .../indexes/prompts/entity_summarization.py | 2 +- .../prompts/knowledge_triplet_extraction.py | 2 +- .../langchain/indexes/vectorstore.py | 11 +- libs/langchain/langchain/input.py | 2 +- libs/langchain/langchain/llms/ai21.py | 5 +- libs/langchain/langchain/llms/aleph_alpha.py | 6 +- .../langchain/llms/amazon_api_gateway.py | 2 +- libs/langchain/langchain/llms/anthropic.py | 21 +- libs/langchain/langchain/llms/anyscale.py | 10 +- libs/langchain/langchain/llms/arcee.py | 3 +- libs/langchain/langchain/llms/aviary.py | 2 +- .../langchain/llms/azureml_endpoint.py | 3 +- .../langchain/llms/baidu_qianfan_endpoint.py | 5 +- libs/langchain/langchain/llms/bananadev.py | 3 +- libs/langchain/langchain/llms/base.py | 1088 +----- libs/langchain/langchain/llms/baseten.py | 3 +- libs/langchain/langchain/llms/beam.py | 2 +- libs/langchain/langchain/llms/bedrock.py | 5 +- libs/langchain/langchain/llms/cerebriumai.py | 3 +- libs/langchain/langchain/llms/clarifai.py | 5 +- libs/langchain/langchain/llms/cohere.py | 4 +- .../langchain/langchain/llms/ctransformers.py | 3 +- libs/langchain/langchain/llms/ctranslate2.py | 5 +- libs/langchain/langchain/llms/databricks.py | 8 +- libs/langchain/langchain/llms/deepinfra.py | 5 +- libs/langchain/langchain/llms/deepsparse.py | 4 +- libs/langchain/langchain/llms/edenai.py | 2 +- libs/langchain/langchain/llms/fake.py | 5 +- libs/langchain/langchain/llms/fireworks.py | 7 +- libs/langchain/langchain/llms/forefrontai.py | 2 +- libs/langchain/langchain/llms/gigachat.py | 7 +- libs/langchain/langchain/llms/google_palm.py | 4 +- libs/langchain/langchain/llms/gooseai.py | 6 +- libs/langchain/langchain/llms/gpt4all.py | 3 +- libs/langchain/langchain/llms/gradient_ai.py | 4 +- .../langchain/llms/huggingface_endpoint.py | 2 +- .../langchain/llms/huggingface_hub.py | 3 +- .../langchain/llms/huggingface_pipeline.py | 5 +- .../llms/huggingface_text_gen_inference.py | 7 +- libs/langchain/langchain/llms/human.py | 3 +- .../langchain/llms/javelin_ai_gateway.py | 3 +- libs/langchain/langchain/llms/llamacpp.py | 9 +- libs/langchain/langchain/llms/manifest.py | 3 +- libs/langchain/langchain/llms/minimax.py | 2 +- .../langchain/llms/mlflow_ai_gateway.py | 3 +- libs/langchain/langchain/llms/modal.py | 2 +- libs/langchain/langchain/llms/mosaicml.py | 2 +- libs/langchain/langchain/llms/nlpcloud.py | 3 +- .../langchain/llms/octoai_endpoint.py | 3 +- libs/langchain/langchain/llms/ollama.py | 8 +- .../langchain/langchain/llms/opaqueprompts.py | 5 +- libs/langchain/langchain/llms/openai.py | 12 +- libs/langchain/langchain/llms/openllm.py | 3 +- libs/langchain/langchain/llms/openlm.py | 3 +- .../langchain/llms/pai_eas_endpoint.py | 4 +- libs/langchain/langchain/llms/petals.py | 3 +- libs/langchain/langchain/llms/pipelineai.py | 3 +- libs/langchain/langchain/llms/predibase.py | 3 +- .../langchain/llms/predictionguard.py | 3 +- .../langchain/llms/promptlayer_openai.py | 3 +- libs/langchain/langchain/llms/replicate.py | 5 +- libs/langchain/langchain/llms/rwkv.py | 3 +- .../langchain/llms/sagemaker_endpoint.py | 3 +- libs/langchain/langchain/llms/self_hosted.py | 3 +- .../llms/self_hosted_hugging_face.py | 3 +- libs/langchain/langchain/llms/stochasticai.py | 2 +- .../langchain/llms/symblai_nebula.py | 6 +- libs/langchain/langchain/llms/textgen.py | 4 +- 
.../langchain/langchain/llms/titan_takeoff.py | 2 +- .../langchain/llms/titan_takeoff_pro.py | 2 +- libs/langchain/langchain/llms/together.py | 2 +- libs/langchain/langchain/llms/tongyi.py | 4 +- libs/langchain/langchain/llms/vertexai.py | 13 +- libs/langchain/langchain/llms/vllm.py | 5 +- libs/langchain/langchain/llms/writer.py | 2 +- libs/langchain/langchain/llms/yandex.py | 5 +- libs/langchain/langchain/load/__init__.py | 4 +- libs/langchain/langchain/load/dump.py | 27 +- libs/langchain/langchain/load/load.py | 127 +- libs/langchain/langchain/load/serializable.py | 224 +- libs/langchain/langchain/memory/buffer.py | 5 +- .../langchain/memory/buffer_window.py | 3 +- .../langchain/langchain/memory/chat_memory.py | 5 +- .../chat_message_histories/cassandra.py | 8 +- .../chat_message_histories/cosmos_db.py | 8 +- .../memory/chat_message_histories/dynamodb.py | 4 +- .../chat_message_histories/elasticsearch.py | 8 +- .../memory/chat_message_histories/file.py | 8 +- .../chat_message_histories/firestore.py | 8 +- .../chat_message_histories/in_memory.py | 6 +- .../memory/chat_message_histories/momento.py | 9 +- .../memory/chat_message_histories/mongodb.py | 8 +- .../memory/chat_message_histories/neo4j.py | 5 +- .../memory/chat_message_histories/postgres.py | 8 +- .../memory/chat_message_histories/redis.py | 9 +- .../chat_message_histories/rocksetdb.py | 8 +- .../chat_message_histories/singlestoredb.py | 8 +- .../memory/chat_message_histories/sql.py | 11 +- .../chat_message_histories/streamlit.py | 4 +- .../chat_message_histories/upstash_redis.py | 8 +- .../memory/chat_message_histories/xata.py | 8 +- .../memory/chat_message_histories/zep.py | 4 +- libs/langchain/langchain/memory/combined.py | 5 +- libs/langchain/langchain/memory/entity.py | 9 +- libs/langchain/langchain/memory/kg.py | 9 +- .../langchain/memory/motorhead_memory.py | 2 +- libs/langchain/langchain/memory/prompt.py | 2 +- libs/langchain/langchain/memory/readonly.py | 2 +- libs/langchain/langchain/memory/simple.py | 2 +- libs/langchain/langchain/memory/summary.py | 15 +- .../langchain/memory/summary_buffer.py | 5 +- .../langchain/memory/token_buffer.py | 5 +- .../langchain/langchain/memory/vectorstore.py | 7 +- libs/langchain/langchain/model_laboratory.py | 5 +- .../langchain/output_parsers/boolean.py | 2 +- .../langchain/output_parsers/combining.py | 4 +- .../langchain/output_parsers/datetime.py | 3 +- .../langchain/output_parsers/enum.py | 4 +- .../langchain/langchain/output_parsers/fix.py | 9 +- .../langchain/output_parsers/json.py | 3 +- .../langchain/output_parsers/list.py | 90 +- .../output_parsers/openai_functions.py | 10 +- .../langchain/output_parsers/openai_tools.py | 6 +- .../langchain/output_parsers/prompts.py | 2 +- .../langchain/output_parsers/pydantic.py | 5 +- .../langchain/output_parsers/rail_parser.py | 2 +- .../langchain/output_parsers/regex.py | 2 +- .../langchain/output_parsers/regex_dict.py | 2 +- .../langchain/output_parsers/retry.py | 6 +- .../langchain/output_parsers/structured.py | 5 +- .../langchain/langchain/output_parsers/xml.py | 3 +- libs/langchain/langchain/prompts/__init__.py | 23 +- libs/langchain/langchain/prompts/base.py | 190 +- libs/langchain/langchain/prompts/chat.py | 769 +---- .../prompts/example_selector/__init__.py | 11 +- .../prompts/example_selector/base.py | 16 +- .../prompts/example_selector/length_based.py | 66 +- .../prompts/example_selector/ngram_overlap.py | 7 +- .../example_selector/semantic_similarity.py | 174 +- libs/langchain/langchain/prompts/few_shot.py | 342 +- 
.../prompts/few_shot_with_templates.py | 154 +- libs/langchain/langchain/prompts/loading.py | 165 +- libs/langchain/langchain/prompts/pipeline.py | 57 +- libs/langchain/langchain/prompts/prompt.py | 251 +- libs/langchain/langchain/retrievers/arcee.py | 5 +- libs/langchain/langchain/retrievers/arxiv.py | 3 +- .../retrievers/azure_cognitive_search.py | 4 +- libs/langchain/langchain/retrievers/bm25.py | 3 +- .../langchain/retrievers/chaindesk.py | 2 +- .../retrievers/chatgpt_plugin_retriever.py | 2 +- .../retrievers/cohere_rag_retriever.py | 7 +- .../retrievers/contextual_compression.py | 3 +- .../langchain/retrievers/databerry.py | 2 +- .../langchain/retrievers/docarray.py | 4 +- .../retrievers/document_compressors/base.py | 5 +- .../document_compressors/chain_extract.py | 7 +- .../document_compressors/chain_filter.py | 7 +- .../document_compressors/cohere_rerank.py | 5 +- .../document_compressors/embeddings_filter.py | 6 +- .../retrievers/elastic_search_bm25.py | 3 +- .../langchain/retrievers/ensemble.py | 5 +- .../google_cloud_documentai_warehouse.py | 5 +- .../retrievers/google_vertex_ai_search.py | 5 +- libs/langchain/langchain/retrievers/kay.py | 3 +- libs/langchain/langchain/retrievers/kendra.py | 5 +- libs/langchain/langchain/retrievers/knn.py | 4 +- .../langchain/retrievers/llama_index.py | 5 +- .../langchain/retrievers/merger_retriever.py | 3 +- libs/langchain/langchain/retrievers/metal.py | 5 +- libs/langchain/langchain/retrievers/milvus.py | 7 +- .../langchain/retrievers/multi_query.py | 7 +- .../langchain/retrievers/multi_vector.py | 7 +- .../retrievers/parent_document_retriever.py | 3 +- .../retrievers/pinecone_hybrid_search.py | 7 +- libs/langchain/langchain/retrievers/pubmed.py | 3 +- .../langchain/retrievers/re_phraser.py | 5 +- .../langchain/retrievers/remote_retriever.py | 2 +- .../langchain/retrievers/self_query/base.py | 11 +- libs/langchain/langchain/retrievers/svm.py | 4 +- .../langchain/retrievers/tavily_search_api.py | 5 +- libs/langchain/langchain/retrievers/tfidf.py | 3 +- .../retrievers/time_weighted_retriever.py | 7 +- .../langchain/retrievers/vespa_retriever.py | 3 +- .../retrievers/weaviate_hybrid_search.py | 5 +- .../langchain/retrievers/web_research.py | 9 +- .../langchain/retrievers/wikipedia.py | 3 +- libs/langchain/langchain/retrievers/you.py | 5 +- libs/langchain/langchain/retrievers/zep.py | 5 +- libs/langchain/langchain/retrievers/zilliz.py | 7 +- libs/langchain/langchain/runnables/hub.py | 2 +- .../langchain/runnables/openai_functions.py | 6 +- libs/langchain/langchain/schema/__init__.py | 26 +- libs/langchain/langchain/schema/agent.py | 75 +- libs/langchain/langchain/schema/cache.py | 25 +- .../langchain/schema/callbacks/base.py | 615 +--- .../langchain/schema/callbacks/manager.py | 2126 +----------- .../langchain/schema/callbacks/stdout.py | 98 +- .../schema/callbacks/streaming_stdout.py | 68 +- .../schema/callbacks/tracers/base.py | 538 +-- .../schema/callbacks/tracers/evaluation.py | 226 +- .../schema/callbacks/tracers/langchain.py | 266 +- .../schema/callbacks/tracers/langchain_v1.py | 186 +- .../schema/callbacks/tracers/log_stream.py | 316 +- .../callbacks/tracers/root_listeners.py | 55 +- .../schema/callbacks/tracers/run_collector.py | 53 +- .../schema/callbacks/tracers/schemas.py | 155 +- .../schema/callbacks/tracers/stdout.py | 189 +- libs/langchain/langchain/schema/chat.py | 14 +- .../langchain/schema/chat_history.py | 68 +- libs/langchain/langchain/schema/document.py | 92 +- libs/langchain/langchain/schema/embeddings.py | 28 +- 
libs/langchain/langchain/schema/exceptions.py | 5 +- .../langchain/schema/language_model.py | 292 +- libs/langchain/langchain/schema/memory.py | 60 +- libs/langchain/langchain/schema/messages.py | 454 +-- libs/langchain/langchain/schema/output.py | 192 +- .../langchain/schema/output_parser.py | 490 +-- libs/langchain/langchain/schema/prompt.py | 29 +- .../langchain/schema/prompt_template.py | 229 +- libs/langchain/langchain/schema/retriever.py | 276 +- .../langchain/schema/runnable/__init__.py | 14 +- .../langchain/schema/runnable/base.py | 3047 +---------------- .../langchain/schema/runnable/branch.py | 255 +- .../langchain/schema/runnable/config.py | 422 +-- .../langchain/schema/runnable/configurable.py | 399 +-- .../langchain/schema/runnable/fallbacks.py | 345 +- .../langchain/schema/runnable/history.py | 289 +- .../langchain/schema/runnable/passthrough.py | 456 +-- .../langchain/schema/runnable/retry.py | 338 +- .../langchain/schema/runnable/router.py | 207 +- .../langchain/schema/runnable/utils.py | 358 +- libs/langchain/langchain/schema/storage.py | 54 +- .../langchain/langchain/schema/vectorstore.py | 703 +--- .../langchain/smith/evaluation/config.py | 8 +- .../langchain/smith/evaluation/progress.py | 5 +- .../smith/evaluation/runner_utils.py | 14 +- .../smith/evaluation/string_run_evaluator.py | 10 +- libs/langchain/langchain/storage/_lc_store.py | 9 +- .../langchain/storage/encoder_backed.py | 2 +- .../langchain/langchain/storage/exceptions.py | 2 +- .../langchain/storage/file_system.py | 3 +- libs/langchain/langchain/storage/in_memory.py | 2 +- libs/langchain/langchain/storage/redis.py | 3 +- .../langchain/storage/upstash_redis.py | 2 +- libs/langchain/langchain/text_splitter.py | 2 +- .../langchain/tools/ainetwork/app.py | 3 +- .../langchain/tools/ainetwork/base.py | 3 +- .../langchain/tools/ainetwork/owner.py | 3 +- .../langchain/tools/ainetwork/rule.py | 3 +- .../langchain/tools/ainetwork/transfer.py | 3 +- .../langchain/tools/ainetwork/value.py | 3 +- .../langchain/langchain/tools/amadeus/base.py | 3 +- .../tools/amadeus/closest_airport.py | 3 +- .../langchain/tools/amadeus/flight_search.py | 3 +- libs/langchain/langchain/tools/arxiv/tool.py | 3 +- .../form_recognizer.py | 3 +- .../image_analysis.py | 3 +- .../azure_cognitive_services/speech2text.py | 3 +- .../azure_cognitive_services/text2speech.py | 3 +- libs/langchain/langchain/tools/base.py | 860 +---- libs/langchain/langchain/tools/bearly/tool.py | 2 +- .../langchain/langchain/tools/clickup/tool.py | 3 +- .../tools/dataforseo_api_search/tool.py | 3 +- .../langchain/tools/ddg_search/tool.py | 3 +- .../langchain/tools/e2b_data_analysis/tool.py | 4 +- .../tools/edenai/audio_speech_to_text.py | 2 +- .../tools/edenai/audio_text_to_speech.py | 2 +- .../tools/edenai/edenai_base_tool.py | 2 +- .../tools/eleven_labs/text2speech.py | 3 +- .../langchain/tools/file_management/copy.py | 3 +- .../langchain/tools/file_management/delete.py | 3 +- .../tools/file_management/file_search.py | 3 +- .../tools/file_management/list_dir.py | 3 +- .../langchain/tools/file_management/move.py | 3 +- .../langchain/tools/file_management/read.py | 3 +- .../langchain/tools/file_management/utils.py | 2 +- .../langchain/tools/file_management/write.py | 3 +- libs/langchain/langchain/tools/github/tool.py | 3 +- libs/langchain/langchain/tools/gitlab/tool.py | 3 +- libs/langchain/langchain/tools/gmail/base.py | 3 +- .../langchain/tools/gmail/create_draft.py | 3 +- .../langchain/tools/gmail/get_message.py | 3 +- .../langchain/tools/gmail/get_thread.py | 3 +- 
.../langchain/langchain/tools/gmail/search.py | 3 +- .../langchain/tools/gmail/send_message.py | 3 +- .../langchain/tools/google_places/tool.py | 3 +- .../langchain/tools/google_serper/tool.py | 3 +- libs/langchain/langchain/tools/human/tool.py | 3 +- libs/langchain/langchain/tools/jira/tool.py | 3 +- libs/langchain/langchain/tools/json/tool.py | 2 +- .../langchain/tools/memorize/tool.py | 3 +- .../langchain/tools/multion/close_session.py | 3 +- .../langchain/tools/multion/create_session.py | 3 +- .../langchain/tools/multion/update_session.py | 3 +- libs/langchain/langchain/tools/nuclia/tool.py | 2 +- .../langchain/tools/office365/base.py | 3 +- .../tools/office365/create_draft_message.py | 3 +- .../tools/office365/events_search.py | 3 +- .../tools/office365/messages_search.py | 3 +- .../langchain/tools/office365/send_event.py | 3 +- .../langchain/tools/office365/send_message.py | 3 +- .../tools/openapi/utils/api_models.py | 3 +- .../langchain/tools/openweathermap/tool.py | 3 +- .../langchain/tools/playwright/base.py | 3 +- .../langchain/tools/playwright/click.py | 3 +- .../tools/playwright/current_page.py | 3 +- .../tools/playwright/extract_hyperlinks.py | 3 +- .../tools/playwright/extract_text.py | 3 +- .../tools/playwright/get_elements.py | 3 +- .../langchain/tools/playwright/navigate.py | 3 +- .../tools/playwright/navigate_back.py | 3 +- libs/langchain/langchain/tools/plugin.py | 2 +- .../langchain/langchain/tools/powerbi/tool.py | 3 +- libs/langchain/langchain/tools/pubmed/tool.py | 3 +- .../langchain/tools/requests/tool.py | 2 +- libs/langchain/langchain/tools/retriever.py | 5 +- .../langchain/tools/scenexplain/tool.py | 3 +- .../langchain/tools/searchapi/tool.py | 3 +- .../langchain/tools/searx_search/tool.py | 6 +- libs/langchain/langchain/tools/shell/tool.py | 3 +- libs/langchain/langchain/tools/sleep/tool.py | 3 +- .../langchain/tools/spark_sql/tool.py | 6 +- .../langchain/tools/sql_database/tool.py | 6 +- .../tools/steamship_image_generation/tool.py | 3 +- .../langchain/tools/tavily_search/tool.py | 3 +- .../langchain/tools/vectorstore/tool.py | 7 +- .../langchain/tools/yahoo_finance_news.py | 2 +- libs/langchain/langchain/tools/zapier/tool.py | 5 +- .../langchain/utilities/alpha_vantage.py | 2 +- libs/langchain/langchain/utilities/apify.py | 5 +- libs/langchain/langchain/utilities/arcee.py | 5 +- libs/langchain/langchain/utilities/arxiv.py | 4 +- .../langchain/utilities/awslambda.py | 2 +- libs/langchain/langchain/utilities/bibtex.py | 2 +- .../langchain/utilities/bing_search.py | 2 +- .../langchain/utilities/brave_search.py | 5 +- libs/langchain/langchain/utilities/clickup.py | 2 +- .../utilities/dalle_image_generator.py | 3 +- .../utilities/dataforseo_api_search.py | 2 +- .../langchain/utilities/duckduckgo_search.py | 2 +- libs/langchain/langchain/utilities/github.py | 3 +- libs/langchain/langchain/utilities/gitlab.py | 3 +- .../langchain/utilities/golden_query.py | 2 +- .../langchain/utilities/google_places_api.py | 3 +- .../langchain/utilities/google_scholar.py | 3 +- .../langchain/utilities/google_search.py | 3 +- .../langchain/utilities/google_serper.py | 2 +- libs/langchain/langchain/utilities/graphql.py | 2 +- libs/langchain/langchain/utilities/jira.py | 3 +- libs/langchain/langchain/utilities/loading.py | 2 +- .../langchain/utilities/metaphor_search.py | 2 +- libs/langchain/langchain/utilities/openapi.py | 3 +- .../langchain/utilities/openweathermap.py | 3 +- libs/langchain/langchain/utilities/powerbi.py | 3 +- libs/langchain/langchain/utilities/pubmed.py | 4 +- 
libs/langchain/langchain/utilities/python.py | 2 +- .../langchain/langchain/utilities/requests.py | 3 +- .../langchain/utilities/scenexplain.py | 2 +- .../langchain/utilities/searchapi.py | 2 +- .../langchain/utilities/searx_search.py | 4 +- libs/langchain/langchain/utilities/serpapi.py | 2 +- .../langchain/utilities/tavily_search.py | 2 +- .../utilities/tensorflow_datasets.py | 4 +- libs/langchain/langchain/utilities/twilio.py | 3 +- .../langchain/utilities/wikipedia.py | 4 +- .../langchain/utilities/wolfram_alpha.py | 3 +- libs/langchain/langchain/utilities/zapier.py | 2 +- libs/langchain/langchain/utils/__init__.py | 13 +- libs/langchain/langchain/utils/aiter.py | 210 +- libs/langchain/langchain/utils/formatting.py | 39 +- libs/langchain/langchain/utils/input.py | 48 +- libs/langchain/langchain/utils/iter.py | 176 +- libs/langchain/langchain/utils/loading.py | 55 +- .../langchain/utils/openai_functions.py | 3 +- libs/langchain/langchain/utils/pydantic.py | 15 +- libs/langchain/langchain/utils/utils.py | 199 +- .../langchain/vectorstores/__init__.py | 2 +- .../vectorstores/alibabacloud_opensearch.py | 6 +- .../langchain/vectorstores/analyticdb.py | 5 +- .../langchain/langchain/vectorstores/annoy.py | 4 +- .../langchain/vectorstores/astradb.py | 6 +- .../langchain/langchain/vectorstores/atlas.py | 4 +- .../langchain/langchain/vectorstores/awadb.py | 4 +- .../langchain/vectorstores/azure_cosmos_db.py | 3 +- .../langchain/vectorstores/azuresearch.py | 8 +- .../langchain/vectorstores/bageldb.py | 7 +- .../vectorstores/baiducloud_vector_search.py | 5 +- libs/langchain/langchain/vectorstores/base.py | 2 +- .../langchain/vectorstores/cassandra.py | 5 +- .../langchain/vectorstores/chroma.py | 6 +- .../langchain/vectorstores/clarifai.py | 4 +- .../langchain/vectorstores/clickhouse.py | 7 +- .../langchain/vectorstores/dashvector.py | 4 +- .../langchain/vectorstores/deeplake.py | 5 +- .../langchain/langchain/vectorstores/dingo.py | 4 +- .../langchain/vectorstores/docarray/base.py | 8 +- .../langchain/vectorstores/docarray/hnsw.py | 3 +- .../vectorstores/docarray/in_memory.py | 3 +- .../vectorstores/elastic_vector_search.py | 7 +- .../langchain/vectorstores/elasticsearch.py | 4 +- .../langchain/vectorstores/epsilla.py | 5 +- .../langchain/langchain/vectorstores/faiss.py | 4 +- .../langchain/langchain/vectorstores/hippo.py | 5 +- .../langchain/vectorstores/hologres.py | 5 +- .../langchain/vectorstores/lancedb.py | 5 +- .../langchain/vectorstores/llm_rails.py | 6 +- .../langchain/langchain/vectorstores/marqo.py | 5 +- .../langchain/vectorstores/matching_engine.py | 7 +- .../langchain/vectorstores/meilisearch.py | 5 +- .../langchain/vectorstores/milvus.py | 4 +- .../vectorstores/momento_vector_index.py | 5 +- .../langchain/vectorstores/mongodb_atlas.py | 4 +- .../langchain/vectorstores/myscale.py | 7 +- .../langchain/vectorstores/neo4j_vector.py | 5 +- .../langchain/vectorstores/nucliadb.py | 6 +- .../vectorstores/opensearch_vector_search.py | 6 +- .../langchain/vectorstores/pgembedding.py | 7 +- .../langchain/vectorstores/pgvecto_rs.py | 7 +- .../langchain/vectorstores/pgvector.py | 5 +- .../langchain/vectorstores/pinecone.py | 6 +- .../langchain/vectorstores/qdrant.py | 4 +- .../langchain/vectorstores/redis/base.py | 6 +- .../langchain/vectorstores/redis/schema.py | 2 +- .../langchain/vectorstores/rocksetdb.py | 5 +- .../langchain/langchain/vectorstores/scann.py | 4 +- .../langchain/vectorstores/semadb.py | 6 +- .../langchain/vectorstores/singlestoredb.py | 4 +- .../langchain/vectorstores/sklearn.py 
| 7 +- .../langchain/vectorstores/sqlitevss.py | 5 +- .../langchain/vectorstores/starrocks.py | 7 +- .../langchain/vectorstores/supabase.py | 6 +- libs/langchain/langchain/vectorstores/tair.py | 5 +- .../langchain/vectorstores/tencentvectordb.py | 6 +- .../langchain/vectorstores/tigris.py | 6 +- .../langchain/vectorstores/tiledb.py | 4 +- .../langchain/vectorstores/timescalevector.py | 7 +- .../langchain/vectorstores/typesense.py | 5 +- .../langchain/vectorstores/usearch.py | 4 +- libs/langchain/langchain/vectorstores/vald.py | 4 +- .../langchain/vectorstores/vearch.py | 4 +- .../langchain/vectorstores/vectara.py | 9 +- .../langchain/langchain/vectorstores/vespa.py | 3 +- .../langchain/vectorstores/weaviate.py | 4 +- libs/langchain/langchain/vectorstores/xata.py | 5 +- libs/langchain/langchain/vectorstores/zep.py | 5 +- .../langchain/vectorstores/zilliz.py | 3 +- libs/langchain/scripts/check_imports.sh | 30 +- libs/langchain/scripts/check_pydantic.sh | 4 +- .../integration_tests/cache/test_cassandra.py | 2 +- .../integration_tests/cache/test_gptcache.py | 2 +- .../cache/test_momento_cache.py | 2 +- .../cache/test_redis_cache.py | 10 +- .../cache/test_upstash_redis_cache.py | 2 +- .../callbacks/test_langchain_tracer.py | 2 +- .../chat_models/test_anthropic.py | 4 +- .../chat_models/test_azure_openai.py | 10 +- .../chat_models/test_azureml_endpoint.py | 11 +- .../chat_models/test_baichuan.py | 3 +- .../chat_models/test_bedrock.py | 4 +- .../chat_models/test_ernie.py | 2 +- .../chat_models/test_fireworks.py | 4 +- .../chat_models/test_google_palm.py | 8 +- .../chat_models/test_hunyuan.py | 3 +- .../chat_models/test_jinachat.py | 8 +- .../chat_models/test_konko.py | 10 +- .../chat_models/test_litellm.py | 11 +- .../chat_models/test_openai.py | 18 +- .../chat_models/test_pai_eas_chat_endpoint.py | 7 +- .../chat_models/test_promptlayer_openai.py | 10 +- .../chat_models/test_qianfan_endpoint.py | 15 +- .../chat_models/test_tongyi.py | 7 +- .../chat_models/test_vertexai.py | 4 +- .../document_loaders/parsers/test_docai.py | 3 +- .../document_loaders/test_arxiv.py | 2 +- .../document_loaders/test_dataframe.py | 2 +- .../document_loaders/test_geodataframe.py | 2 +- .../document_loaders/test_polars_dataframe.py | 2 +- .../document_loaders/test_pubmed.py | 2 +- .../document_loaders/test_quip.py | 2 +- .../test_tensorflow_datasets.py | 4 +- .../document_loaders/test_xorbits.py | 2 +- .../integration_tests/llms/test_anthropic.py | 2 +- .../llms/test_azure_openai.py | 6 +- .../integration_tests/llms/test_chatglm.py | 3 +- .../integration_tests/llms/test_fireworks.py | 12 +- .../llms/test_opaqueprompts.py | 7 +- .../integration_tests/llms/test_openai.py | 2 +- .../llms/test_qianfan_endpoint.py | 2 +- .../llms/test_symblai_nebula.py | 3 +- .../integration_tests/llms/test_tongyi.py | 3 +- .../integration_tests/llms/test_vertexai.py | 2 +- .../memory/chat_message_histories/test_zep.py | 2 +- .../memory/test_cassandra.py | 2 +- .../memory/test_cosmos_db.py | 3 +- .../memory/test_elasticsearch.py | 2 +- .../memory/test_firestore.py | 3 +- .../integration_tests/memory/test_momento.py | 2 +- .../integration_tests/memory/test_mongodb.py | 3 +- .../integration_tests/memory/test_neo4j.py | 3 +- .../integration_tests/memory/test_redis.py | 3 +- .../integration_tests/memory/test_rockset.py | 3 +- .../memory/test_singlestoredb.py | 3 +- .../memory/test_upstash_redis.py | 2 +- .../integration_tests/memory/test_xata.py | 3 +- .../test_ngram_overlap_example_selector.py | 2 +- .../retrievers/docarray/fixtures.py | 3 +- 
.../document_compressors/test_base.py | 3 +- .../test_chain_extract.py | 3 +- .../document_compressors/test_chain_filter.py | 3 +- .../test_embeddings_filter.py | 2 +- .../retrievers/test_arxiv.py | 2 +- .../retrievers/test_azure_cognitive_search.py | 2 +- .../test_google_docai_warehoure_retriever.py | 3 +- .../test_google_vertex_ai_search.py | 2 +- .../integration_tests/retrievers/test_kay.py | 2 +- .../retrievers/test_pubmed.py | 2 +- .../retrievers/test_wikipedia.py | 2 +- .../integration_tests/retrievers/test_zep.py | 2 +- .../smith/evaluation/test_runner_utils.py | 4 +- .../test_document_transformers.py | 3 +- .../test_nuclia_transformer.py | 2 +- .../tests/integration_tests/test_schema.py | 2 +- .../integration_tests/utilities/test_arxiv.py | 2 +- .../utilities/test_pubmed.py | 2 +- .../utilities/test_tensorflow_datasets.py | 4 +- .../utilities/test_wikipedia_api.py | 2 +- .../vectorstores/conftest.py | 2 +- .../vectorstores/docarray/test_hnsw.py | 2 +- .../vectorstores/docarray/test_in_memory.py | 2 +- .../vectorstores/fake_embeddings.py | 2 +- .../qdrant/async_api/test_from_texts.py | 2 +- .../async_api/test_max_marginal_relevance.py | 2 +- .../async_api/test_similarity_search.py | 2 +- .../vectorstores/qdrant/test_add_texts.py | 2 +- .../qdrant/test_embedding_interface.py | 2 +- .../vectorstores/qdrant/test_from_texts.py | 2 +- .../qdrant/test_max_marginal_relevance.py | 2 +- .../qdrant/test_similarity_search.py | 2 +- .../test_alibabacloud_opensearch.py | 3 +- .../vectorstores/test_astradb.py | 2 +- .../vectorstores/test_dashvector.py | 3 +- .../vectorstores/test_mongodb_atlas.py | 2 +- .../vectorstores/test_zep.py | 2 +- .../tests/mock_servers/robot/server.py | 3 +- .../agents/format_scratchpad/test_log.py | 3 +- .../format_scratchpad/test_log_to_messages.py | 5 +- .../test_openai_functions.py | 5 +- .../agents/format_scratchpad/test_xml.py | 3 +- .../agents/output_parsers/test_json.py | 3 +- .../output_parsers/test_openai_functions.py | 6 +- .../test_react_json_single_input.py | 3 +- .../output_parsers/test_react_single_input.py | 4 +- .../agents/output_parsers/test_self_ask.py | 3 +- .../agents/output_parsers/test_xml.py | 3 +- .../tests/unit_tests/agents/test_chat.py | 3 +- .../tests/unit_tests/agents/test_mrkl.py | 4 +- .../agents/test_mrkl_output_parser.py | 2 +- .../agents/test_openai_functions_multi.py | 4 +- .../tests/unit_tests/agents/test_react.py | 5 +- .../unit_tests/agents/test_structured_chat.py | 11 +- .../callbacks/fake_callback_handler.py | 5 +- .../callbacks/test_callback_manager.py | 2 +- .../unit_tests/callbacks/test_openai_info.py | 2 +- .../callbacks/tracers/test_base_tracer.py | 4 +- .../callbacks/tracers/test_langchain.py | 2 +- .../callbacks/tracers/test_langchain_v1.py | 14 +- .../tests/unit_tests/chains/test_base.py | 2 +- .../chains/test_combine_documents.py | 4 +- .../unit_tests/chains/test_conversation.py | 4 +- .../chains/test_conversation_retrieval.py | 3 +- .../tests/unit_tests/chains/test_graph_qa.py | 2 +- .../tests/unit_tests/chains/test_hyde.py | 4 +- .../tests/unit_tests/chains/test_llm.py | 4 +- .../tests/unit_tests/chains/test_memory.py | 2 +- .../unit_tests/chat_loaders/test_telegram.py | 2 +- .../unit_tests/chat_models/test_anthropic.py | 2 +- .../chat_models/test_azureml_endpoint.py | 2 +- .../unit_tests/chat_models/test_baichuan.py | 18 +- .../unit_tests/chat_models/test_ernie.py | 6 +- .../unit_tests/chat_models/test_fireworks.py | 2 +- .../chat_models/test_google_palm.py | 2 +- .../unit_tests/chat_models/test_hunyuan.py | 18 +- 
.../chat_models/test_javelin_ai_gateway.py | 2 +- .../unit_tests/chat_models/test_openai.py | 8 +- .../unit_tests/docstore/test_arbitrary_fn.py | 3 +- .../document_loaders/parsers/test_generic.py | 2 +- .../unit_tests/document_loaders/test_base.py | 3 +- .../document_loaders/test_generic_loader.py | 2 +- .../test_beautiful_soup_transformer.py | 2 +- .../unit_tests/embeddings/test_caching.py | 2 +- .../evaluation/agents/test_eval_chain.py | 4 +- .../indexes/test_hashed_document.py | 2 +- .../tests/unit_tests/indexes/test_indexing.py | 4 +- .../tests/unit_tests/llms/fake_chat_model.py | 5 +- .../tests/unit_tests/llms/fake_llm.py | 3 +- .../tests/unit_tests/llms/test_ai21.py | 2 +- .../tests/unit_tests/llms/test_aleph_alpha.py | 2 +- .../tests/unit_tests/llms/test_anyscale.py | 2 +- .../tests/unit_tests/llms/test_base.py | 2 +- .../tests/unit_tests/llms/test_callbacks.py | 3 +- .../tests/unit_tests/llms/test_fireworks.py | 2 +- .../tests/unit_tests/llms/test_gooseai.py | 2 +- .../unit_tests/llms/test_symblai_nebula.py | 2 +- .../load/__snapshots__/test_dump.ambr | 10 +- .../tests/unit_tests/load/test_dump.py | 8 +- .../tests/unit_tests/load/test_load.py | 6 +- .../chat_message_histories/test_file.py | 2 +- .../memory/chat_message_histories/test_sql.py | 2 +- .../chat_message_histories/test_streamlit.py | 2 +- .../unit_tests/memory/test_combined_memory.py | 2 +- .../output_parsers/test_enum_parser.py | 3 +- .../unit_tests/output_parsers/test_json.py | 2 +- .../output_parsers/test_openai_functions.py | 4 +- .../output_parsers/test_pydantic_parser.py | 5 +- .../output_parsers/test_structured_parser.py | 3 +- .../retrievers/self_query/test_base.py | 2 +- .../retrievers/sequential_retriever.py | 2 +- .../tests/unit_tests/retrievers/test_base.py | 2 +- .../tests/unit_tests/retrievers/test_bm25.py | 2 +- .../unit_tests/retrievers/test_ensemble.py | 2 +- .../unit_tests/retrievers/test_multi_query.py | 2 +- .../retrievers/test_remote_retriever.py | 2 +- .../tests/unit_tests/retrievers/test_svm.py | 2 +- .../tests/unit_tests/retrievers/test_tfidf.py | 2 +- .../test_time_weighted_retriever.py | 6 +- .../tests/unit_tests/retrievers/test_you.py | 2 +- .../tests/unit_tests/runnables/test_hub.py | 9 +- .../runnables/test_openai_functions.py | 6 +- .../tests/unit_tests/schema/test_imports.py | 2 +- .../tests/unit_tests/schema/test_messages.py | 3 +- .../tests/unit_tests/schema/test_output.py | 4 +- .../smith/evaluation/test_runner_utils.py | 2 +- .../tests/unit_tests/storage/test_lc_store.py | 2 +- libs/langchain/tests/unit_tests/test_cache.py | 13 +- .../tests/unit_tests/test_dependencies.py | 3 +- .../tests/unit_tests/test_formatting.py | 3 +- .../tests/unit_tests/test_globals.py | 6 +- .../langchain/tests/unit_tests/test_schema.py | 15 +- libs/langchain/tests/unit_tests/test_utils.py | 3 +- .../tools/openapi/test_api_models.py | 2 +- .../tests/unit_tests/tools/test_exported.py | 8 +- .../unit_tests/utilities/test_loading.py | 3 +- .../tests/unit_tests/utils/test_iter.py | 3 +- .../unit_tests/utils/test_openai_functions.py | 3 +- .../vectorstores/redis/test_redis_schema.py | 2 +- .../unit_tests/vectorstores/test_imports.py | 3 +- 1153 files changed, 28106 insertions(+), 22917 deletions(-) create mode 100644 .github/workflows/langchain_core_ci.yml create mode 100644 .github/workflows/langchain_core_release.yml create mode 100644 libs/core/Makefile create mode 100644 libs/core/README.md create mode 100644 libs/core/langchain_core/__init__.py create mode 100644 libs/core/langchain_core/_api/__init__.py create 
mode 100644 libs/core/langchain_core/_api/deprecation.py create mode 100644 libs/core/langchain_core/_api/path.py rename libs/{langchain/tests/unit_tests/_api => core/langchain_core/callbacks}/__init__.py (100%) create mode 100644 libs/core/langchain_core/callbacks/base.py create mode 100644 libs/core/langchain_core/callbacks/manager.py create mode 100644 libs/core/langchain_core/callbacks/stdout.py create mode 100644 libs/core/langchain_core/callbacks/streaming_stdout.py rename libs/{langchain/tests/unit_tests/schema/runnable => core/langchain_core/callbacks/tracers}/__init__.py (100%) create mode 100644 libs/core/langchain_core/callbacks/tracers/base.py create mode 100644 libs/core/langchain_core/callbacks/tracers/evaluation.py create mode 100644 libs/core/langchain_core/callbacks/tracers/langchain.py create mode 100644 libs/core/langchain_core/callbacks/tracers/langchain_v1.py create mode 100644 libs/core/langchain_core/callbacks/tracers/log_stream.py create mode 100644 libs/core/langchain_core/callbacks/tracers/root_listeners.py create mode 100644 libs/core/langchain_core/callbacks/tracers/run_collector.py create mode 100644 libs/core/langchain_core/callbacks/tracers/schemas.py create mode 100644 libs/core/langchain_core/callbacks/tracers/stdout.py create mode 100644 libs/core/langchain_core/chat_model.py create mode 100644 libs/core/langchain_core/env.py create mode 100644 libs/core/langchain_core/globals/__init__.py create mode 100644 libs/core/langchain_core/llm.py create mode 100644 libs/core/langchain_core/load/__init__.py create mode 100644 libs/core/langchain_core/load/dump.py create mode 100644 libs/core/langchain_core/load/load.py create mode 100644 libs/core/langchain_core/load/serializable.py create mode 100644 libs/core/langchain_core/output_parsers/__init__.py create mode 100644 libs/core/langchain_core/output_parsers/list.py create mode 100644 libs/core/langchain_core/prompts/__init__.py create mode 100644 libs/core/langchain_core/prompts/base.py create mode 100644 libs/core/langchain_core/prompts/chat.py create mode 100644 libs/core/langchain_core/prompts/example_selector/__init__.py create mode 100644 libs/core/langchain_core/prompts/example_selector/base.py create mode 100644 libs/core/langchain_core/prompts/example_selector/length_based.py create mode 100644 libs/core/langchain_core/prompts/example_selector/semantic_similarity.py create mode 100644 libs/core/langchain_core/prompts/few_shot.py create mode 100644 libs/core/langchain_core/prompts/few_shot_with_templates.py create mode 100644 libs/core/langchain_core/prompts/loading.py create mode 100644 libs/core/langchain_core/prompts/pipeline.py create mode 100644 libs/core/langchain_core/prompts/prompt.py create mode 100644 libs/core/langchain_core/pydantic_v1/__init__.py create mode 100644 libs/core/langchain_core/pydantic_v1/dataclasses.py create mode 100644 libs/core/langchain_core/pydantic_v1/main.py create mode 100644 libs/core/langchain_core/runnables/__init__.py create mode 100644 libs/core/langchain_core/runnables/base.py create mode 100644 libs/core/langchain_core/runnables/branch.py create mode 100644 libs/core/langchain_core/runnables/config.py create mode 100644 libs/core/langchain_core/runnables/configurable.py create mode 100644 libs/core/langchain_core/runnables/fallbacks.py create mode 100644 libs/core/langchain_core/runnables/history.py create mode 100644 libs/core/langchain_core/runnables/passthrough.py create mode 100644 libs/core/langchain_core/runnables/retry.py create mode 100644 
libs/core/langchain_core/runnables/router.py create mode 100644 libs/core/langchain_core/runnables/utils.py create mode 100644 libs/core/langchain_core/schema/__init__.py create mode 100644 libs/core/langchain_core/schema/agent.py create mode 100644 libs/core/langchain_core/schema/cache.py create mode 100644 libs/core/langchain_core/schema/chat.py create mode 100644 libs/core/langchain_core/schema/chat_history.py create mode 100644 libs/core/langchain_core/schema/document.py create mode 100644 libs/core/langchain_core/schema/embeddings.py create mode 100644 libs/core/langchain_core/schema/exceptions.py create mode 100644 libs/core/langchain_core/schema/language_model.py create mode 100644 libs/core/langchain_core/schema/memory.py create mode 100644 libs/core/langchain_core/schema/messages.py create mode 100644 libs/core/langchain_core/schema/output.py create mode 100644 libs/core/langchain_core/schema/output_parser.py create mode 100644 libs/core/langchain_core/schema/prompt.py create mode 100644 libs/core/langchain_core/schema/prompt_template.py create mode 100644 libs/core/langchain_core/schema/retriever.py create mode 100644 libs/core/langchain_core/schema/storage.py create mode 100644 libs/core/langchain_core/schema/vectorstore.py create mode 100644 libs/core/langchain_core/tool.py create mode 100644 libs/core/langchain_core/utils/__init__.py create mode 100644 libs/core/langchain_core/utils/aiter.py create mode 100644 libs/core/langchain_core/utils/formatting.py create mode 100644 libs/core/langchain_core/utils/input.py create mode 100644 libs/core/langchain_core/utils/iter.py create mode 100644 libs/core/langchain_core/utils/loading.py create mode 100644 libs/core/langchain_core/utils/pydantic.py create mode 100644 libs/core/langchain_core/utils/utils.py create mode 100644 libs/core/poetry.lock create mode 100644 libs/core/pyproject.toml create mode 100644 libs/core/tests/__init__.py create mode 100644 libs/core/tests/unit_tests/__init__.py create mode 100644 libs/core/tests/unit_tests/_api/__init__.py rename libs/{langchain => core}/tests/unit_tests/_api/test_deprecation.py (98%) rename libs/{langchain => core}/tests/unit_tests/_api/test_imports.py (86%) rename libs/{langchain => core}/tests/unit_tests/_api/test_path.py (87%) create mode 100644 libs/core/tests/unit_tests/data/prompt_file.txt create mode 100644 libs/core/tests/unit_tests/data/prompts/prompt_extra_args.json create mode 100644 libs/core/tests/unit_tests/data/prompts/prompt_missing_args.json create mode 100644 libs/core/tests/unit_tests/data/prompts/simple_prompt.json create mode 100644 libs/core/tests/unit_tests/examples/example-non-utf8.csv create mode 100644 libs/core/tests/unit_tests/examples/example-non-utf8.txt create mode 100644 libs/core/tests/unit_tests/examples/example-utf8.csv create mode 100644 libs/core/tests/unit_tests/examples/example-utf8.txt rename libs/{langchain => core}/tests/unit_tests/examples/example_prompt.json (100%) rename libs/{langchain => core}/tests/unit_tests/examples/examples.json (100%) rename libs/{langchain => core}/tests/unit_tests/examples/examples.yaml (100%) rename libs/{langchain => core}/tests/unit_tests/examples/few_shot_prompt.json (100%) rename libs/{langchain => core}/tests/unit_tests/examples/few_shot_prompt.yaml (100%) rename libs/{langchain => core}/tests/unit_tests/examples/few_shot_prompt_example_prompt.json (100%) rename libs/{langchain => core}/tests/unit_tests/examples/few_shot_prompt_examples_in.json (100%) rename libs/{langchain => 
core}/tests/unit_tests/examples/few_shot_prompt_yaml_examples.yaml (100%) rename libs/{langchain => core}/tests/unit_tests/examples/jinja_injection_prompt.json (100%) rename libs/{langchain => core}/tests/unit_tests/examples/jinja_injection_prompt.yaml (100%) rename libs/{langchain => core}/tests/unit_tests/examples/prompt_with_output_parser.json (100%) rename libs/{langchain => core}/tests/unit_tests/examples/simple_prompt.json (100%) rename libs/{langchain => core}/tests/unit_tests/examples/simple_prompt.yaml (100%) rename libs/{langchain => core}/tests/unit_tests/examples/simple_prompt_with_template_file.json (100%) rename libs/{langchain => core}/tests/unit_tests/examples/simple_template.txt (100%) create mode 100644 libs/core/tests/unit_tests/fake/__init__.py create mode 100644 libs/core/tests/unit_tests/fake/callbacks.py create mode 100644 libs/core/tests/unit_tests/fake/chat_model.py create mode 100644 libs/core/tests/unit_tests/fake/llm.py create mode 100644 libs/core/tests/unit_tests/fake/memory.py create mode 100644 libs/core/tests/unit_tests/prompt_file.txt rename libs/{langchain => core}/tests/unit_tests/prompts/__init__.py (100%) create mode 100644 libs/core/tests/unit_tests/prompts/prompt_extra_args.json create mode 100644 libs/core/tests/unit_tests/prompts/prompt_missing_args.json create mode 100644 libs/core/tests/unit_tests/prompts/simple_prompt.json rename libs/{langchain => core}/tests/unit_tests/prompts/test_chat.py (98%) rename libs/{langchain => core}/tests/unit_tests/prompts/test_few_shot.py (97%) rename libs/{langchain => core}/tests/unit_tests/prompts/test_few_shot_with_templates.py (93%) rename libs/{langchain => core}/tests/unit_tests/prompts/test_imports.py (90%) rename libs/{langchain => core}/tests/unit_tests/prompts/test_length_based_example_selector.py (92%) rename libs/{langchain => core}/tests/unit_tests/prompts/test_loading.py (87%) rename libs/{langchain => core}/tests/unit_tests/prompts/test_pipeline_prompt.py (88%) rename libs/{langchain => core}/tests/unit_tests/prompts/test_prompt.py (99%) rename libs/{langchain => core}/tests/unit_tests/prompts/test_utils.py (76%) create mode 100644 libs/core/tests/unit_tests/runnable/__init__.py rename libs/{langchain/tests/unit_tests/schema => core/tests/unit_tests}/runnable/__snapshots__/test_runnable.ambr (72%) rename libs/{langchain/tests/unit_tests/schema => core/tests/unit_tests}/runnable/test_config.py (75%) rename libs/{langchain/tests/unit_tests/schema => core/tests/unit_tests}/runnable/test_history.py (95%) rename libs/{langchain/tests/unit_tests/schema => core/tests/unit_tests}/runnable/test_runnable.py (88%) rename libs/{langchain/tests/unit_tests/schema => core/tests/unit_tests}/runnable/test_utils.py (95%) create mode 100644 libs/core/tests/unit_tests/schema/__init__.py create mode 100644 libs/core/tests/unit_tests/schema/test_imports.py create mode 100644 libs/core/tests/unit_tests/schema/test_messages.py create mode 100644 libs/core/tests/unit_tests/schema/test_output.py create mode 100644 libs/core/tests/unit_tests/test_globals.py rename libs/{langchain/tests/unit_tests/tools/test_base.py => core/tests/unit_tests/test_tool.py} (98%) create mode 100644 libs/core/tests/unit_tests/utils/__init__.py create mode 100644 libs/core/tests/unit_tests/utils/test_imports.py create mode 100644 libs/core/tests/unit_tests/utils/test_iter.py diff --git a/.github/workflows/_compile_integration_test.yml b/.github/workflows/_compile_integration_test.yml index d52828084f5..71c1721dd78 100644 --- 
a/.github/workflows/_compile_integration_test.yml +++ b/.github/workflows/_compile_integration_test.yml @@ -7,6 +7,10 @@ on: required: true type: string description: "From which folder this pipeline executes" + langchain-core-location: + required: false + type: string + description: "Relative path to the langchain core library folder" env: POETRY_VERSION: "1.6.1" @@ -40,6 +44,14 @@ jobs: shell: bash run: poetry install --with=test_integration + - name: Install langchain core editable + working-directory: ${{ inputs.working-directory }} + if: ${{ inputs.langchain-core-location }} + env: + LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }} + run: | + poetry run pip install -e "$LANGCHAIN_CORE_LOCATION" + - name: Check integration tests compile shell: bash run: poetry run pytest -m compile tests/integration_tests diff --git a/.github/workflows/_lint.yml b/.github/workflows/_lint.yml index f1f5d5c758e..9575b69fd35 100644 --- a/.github/workflows/_lint.yml +++ b/.github/workflows/_lint.yml @@ -11,6 +11,10 @@ on: required: false type: string description: "Relative path to the langchain library folder" + langchain-core-location: + required: false + type: string + description: "Relative path to the langchain core library folder" env: POETRY_VERSION: "1.6.1" @@ -76,7 +80,15 @@ jobs: env: LANGCHAIN_LOCATION: ${{ inputs.langchain-location }} run: | - pip install -e "$LANGCHAIN_LOCATION" + poetry run pip install -e "$LANGCHAIN_LOCATION" + + - name: Install langchain core editable + working-directory: ${{ inputs.working-directory }} + if: ${{ inputs.langchain-core-location }} + env: + LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }} + run: | + poetry run pip install -e "$LANGCHAIN_CORE_LOCATION" - name: Get .mypy_cache to speed up mypy uses: actions/cache@v3 diff --git a/.github/workflows/_pydantic_compatibility.yml b/.github/workflows/_pydantic_compatibility.yml index 8948836237c..af943fbdfde 100644 --- a/.github/workflows/_pydantic_compatibility.yml +++ b/.github/workflows/_pydantic_compatibility.yml @@ -7,6 +7,14 @@ on: required: true type: string description: "From which folder this pipeline executes" + langchain-location: + required: false + type: string + description: "Relative path to the langchain library folder" + langchain-core-location: + required: false + type: string + description: "Relative path to the langchain core library folder" env: POETRY_VERSION: "1.6.1" @@ -40,6 +48,22 @@ jobs: shell: bash run: poetry install + - name: Install langchain editable + working-directory: ${{ inputs.working-directory }} + if: ${{ inputs.langchain-location }} + env: + LANGCHAIN_LOCATION: ${{ inputs.langchain-location }} + run: | + poetry run pip install -e "$LANGCHAIN_LOCATION" + + - name: Install langchain core editable + working-directory: ${{ inputs.working-directory }} + if: ${{ inputs.langchain-core-location }} + env: + LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }} + run: | + poetry run pip install -e "$LANGCHAIN_CORE_LOCATION" + - name: Install the opposite major version of pydantic # If normal tests use pydantic v1, here we'll use v2, and vice versa. 
shell: bash diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index a122a40058f..4172b899c31 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -7,6 +7,14 @@ on: required: true type: string description: "From which folder this pipeline executes" + langchain-location: + required: false + type: string + description: "Relative path to the langchain library folder" + langchain-core-location: + required: false + type: string + description: "Relative path to the langchain core library folder" env: POETRY_VERSION: "1.6.1" @@ -40,9 +48,26 @@ jobs: shell: bash run: poetry install + - name: Install langchain editable + working-directory: ${{ inputs.working-directory }} + if: ${{ inputs.langchain-location }} + env: + LANGCHAIN_LOCATION: ${{ inputs.langchain-location }} + run: | + poetry run pip install -e "$LANGCHAIN_LOCATION" + + - name: Install langchain core editable + working-directory: ${{ inputs.working-directory }} + if: ${{ inputs.langchain-core-location }} + env: + LANGCHAIN_CORE_LOCATION: ${{ inputs.langchain-core-location }} + run: | + poetry run pip install -e "$LANGCHAIN_CORE_LOCATION" + - name: Run core tests shell: bash - run: make test + run: | + make test - name: Ensure the tests did not create any additional files shell: bash diff --git a/.github/workflows/langchain_ci.yml b/.github/workflows/langchain_ci.yml index 7e9b4a6e625..7e486fffe19 100644 --- a/.github/workflows/langchain_ci.yml +++ b/.github/workflows/langchain_ci.yml @@ -36,6 +36,7 @@ jobs: ./.github/workflows/_lint.yml with: working-directory: libs/langchain + langchain-core-location: ../core secrets: inherit test: @@ -43,6 +44,7 @@ jobs: ./.github/workflows/_test.yml with: working-directory: libs/langchain + langchain-core-location: ../core secrets: inherit compile-integration-tests: @@ -50,6 +52,7 @@ jobs: ./.github/workflows/_compile_integration_test.yml with: working-directory: libs/langchain + langchain-core-location: ../core secrets: inherit pydantic-compatibility: @@ -57,6 +60,7 @@ jobs: ./.github/workflows/_pydantic_compatibility.yml with: working-directory: libs/langchain + langchain-core-location: ../core secrets: inherit extended-tests: @@ -89,6 +93,11 @@ jobs: echo "Running extended tests, installing dependencies with poetry..." poetry install -E extended_testing + - name: Install langchain core editable + shell: bash + run: | + poetry run pip install -e ../core + - name: Run extended tests run: make extended_tests diff --git a/.github/workflows/langchain_core_ci.yml b/.github/workflows/langchain_core_ci.yml new file mode 100644 index 00000000000..dc035ebfeb6 --- /dev/null +++ b/.github/workflows/langchain_core_ci.yml @@ -0,0 +1,52 @@ +--- +name: libs/langchain core CI + +on: + push: + branches: [ master ] + pull_request: + paths: + - '.github/actions/poetry_setup/action.yml' + - '.github/tools/**' + - '.github/workflows/_lint.yml' + - '.github/workflows/_test.yml' + - '.github/workflows/_pydantic_compatibility.yml' + - '.github/workflows/langchain_core_ci.yml' + - 'libs/core/**' + workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI + +# If another push to the same PR or branch happens while this workflow is still running, +# cancel the earlier run in favor of the next run. +# +# There's no point in testing an outdated version of the code. GitHub only allows +# a limited number of job runners to be active at the same time, so it's better to cancel +# pointless jobs early so that more useful jobs can run sooner. 
+concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + POETRY_VERSION: "1.6.1" + WORKDIR: "libs/core" + +jobs: + lint: + uses: + ./.github/workflows/_lint.yml + with: + working-directory: libs/core + secrets: inherit + + test: + uses: + ./.github/workflows/_test.yml + with: + working-directory: libs/core + secrets: inherit + + pydantic-compatibility: + uses: + ./.github/workflows/_pydantic_compatibility.yml + with: + working-directory: libs/core + secrets: inherit diff --git a/.github/workflows/langchain_core_release.yml b/.github/workflows/langchain_core_release.yml new file mode 100644 index 00000000000..244c292c2e3 --- /dev/null +++ b/.github/workflows/langchain_core_release.yml @@ -0,0 +1,13 @@ +--- +name: libs/core Release + +on: + workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI + +jobs: + release: + uses: + ./.github/workflows/_release.yml + with: + working-directory: libs/core + secrets: inherit diff --git a/.github/workflows/langchain_experimental_ci.yml b/.github/workflows/langchain_experimental_ci.yml index 7750f0bcffe..11ea34e50c8 100644 --- a/.github/workflows/langchain_experimental_ci.yml +++ b/.github/workflows/langchain_experimental_ci.yml @@ -36,6 +36,7 @@ jobs: with: working-directory: libs/experimental langchain-location: ../langchain + langchain-core-location: ../core secrets: inherit test: @@ -43,6 +44,8 @@ jobs: ./.github/workflows/_test.yml with: working-directory: libs/experimental + langchain-location: ../langchain + langchain-core-location: ../core secrets: inherit compile-integration-tests: @@ -88,6 +91,7 @@ jobs: echo "Editably installing langchain outside of poetry, to avoid messing up lockfile..." poetry run pip install -e ../langchain + poetry run pip install -e ../core - name: Run tests run: make test diff --git a/docs/api_reference/create_api_rst.py b/docs/api_reference/create_api_rst.py index 3aeb4b0d300..1ac4af2273b 100644 --- a/docs/api_reference/create_api_rst.py +++ b/docs/api_reference/create_api_rst.py @@ -13,8 +13,10 @@ HERE = Path(__file__).parent PKG_DIR = ROOT_DIR / "libs" / "langchain" / "langchain" EXP_DIR = ROOT_DIR / "libs" / "experimental" / "langchain_experimental" +CORE_DIR = ROOT_DIR / "libs" / "core" / "langchain_core" WRITE_FILE = HERE / "api_reference.rst" EXP_WRITE_FILE = HERE / "experimental_api_reference.rst" +CORE_WRITE_FILE = HERE / "core_api_reference.rst" ClassKind = Literal["TypedDict", "Regular", "Pydantic", "enum"] @@ -292,6 +294,17 @@ def _document_langchain_experimental() -> None: def _document_langchain_core() -> None: + """Document the langchain_core package.""" + # Generate core_api_reference.rst + core_members = _load_package_modules(EXP_DIR) + core_doc = ".. 
_core_api_reference:\n\n" + _construct_doc( + "langchain_core", core_members + ) + with open(CORE_WRITE_FILE, "w") as f: + f.write(core_doc) + + +def _document_langchain() -> None: """Document the main langchain package.""" # load top level module members lc_members = _load_package_modules(PKG_DIR) @@ -306,7 +319,6 @@ def _document_langchain_core() -> None: "agents.output_parsers": agents["output_parsers"], "agents.format_scratchpad": agents["format_scratchpad"], "tools.render": tools["render"], - "schema.runnable": schema["runnable"], } ) @@ -318,8 +330,9 @@ def _document_langchain_core() -> None: def main() -> None: """Generate the reference.rst file for each package.""" - _document_langchain_core() + _document_langchain() _document_langchain_experimental() + _document_langchain_core() if __name__ == "__main__": diff --git a/docs/api_reference/themes/scikit-learn-modern/nav.html b/docs/api_reference/themes/scikit-learn-modern/nav.html index 6730903f370..37c59b466d5 100644 --- a/docs/api_reference/themes/scikit-learn-modern/nav.html +++ b/docs/api_reference/themes/scikit-learn-modern/nav.html @@ -34,6 +34,9 @@ + diff --git a/libs/core/Makefile b/libs/core/Makefile new file mode 100644 index 00000000000..0e6253395de --- /dev/null +++ b/libs/core/Makefile @@ -0,0 +1,54 @@ +.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests + +# Default target executed when no arguments are given to make. +all: help + +# Define a variable for the test file path. +TEST_FILE ?= tests/unit_tests/ + +test: + poetry run pytest $(TEST_FILE) + +tests: + poetry run pytest $(TEST_FILE) + +test_watch: + poetry run ptw --snapshot-update --now . -- -x tests/unit_tests + + +###################### +# LINTING AND FORMATTING +###################### + +# Define a variable for Python and notebook files. +PYTHON_FILES=. +lint format: PYTHON_FILES=. +lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/experimental --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') + +lint lint_diff: + poetry run ruff . + [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff + [ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES) + +format format_diff: + poetry run ruff format $(PYTHON_FILES) + poetry run ruff --select I --fix $(PYTHON_FILES) + +spell_check: + poetry run codespell --toml pyproject.toml + +spell_fix: + poetry run codespell --toml pyproject.toml -w + +###################### +# HELP +###################### + +help: + @echo '----' + @echo 'format - run code formatters' + @echo 'lint - run linters' + @echo 'test - run unit tests' + @echo 'tests - run unit tests' + @echo 'test TEST_FILE= - run all tests in file' + @echo 'test_watch - run unit tests in watch mode' diff --git a/libs/core/README.md b/libs/core/README.md new file mode 100644 index 00000000000..ef81c1da1ee --- /dev/null +++ b/libs/core/README.md @@ -0,0 +1 @@ +# langchain-core diff --git a/libs/core/langchain_core/__init__.py b/libs/core/langchain_core/__init__.py new file mode 100644 index 00000000000..ee8f8def005 --- /dev/null +++ b/libs/core/langchain_core/__init__.py @@ -0,0 +1,7 @@ +from importlib import metadata + +try: + __version__ = metadata.version(__package__) +except metadata.PackageNotFoundError: + # Case where package metadata is not available. 
+ __version__ = "" diff --git a/libs/core/langchain_core/_api/__init__.py b/libs/core/langchain_core/_api/__init__.py new file mode 100644 index 00000000000..e013a72129f --- /dev/null +++ b/libs/core/langchain_core/_api/__init__.py @@ -0,0 +1,26 @@ +"""Helper functions for managing the LangChain API. + +This module is only relevant for LangChain developers, not for users. + +.. warning:: + + This module and its submodules are for internal use only. Do not use them + in your own code. We may change the API at any time with no warning. + +""" + +from .deprecation import ( + LangChainDeprecationWarning, + deprecated, + suppress_langchain_deprecation_warning, + surface_langchain_deprecation_warnings, + warn_deprecated, +) + +__all__ = [ + "deprecated", + "LangChainDeprecationWarning", + "suppress_langchain_deprecation_warning", + "surface_langchain_deprecation_warnings", + "warn_deprecated", +] diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py new file mode 100644 index 00000000000..6919504351d --- /dev/null +++ b/libs/core/langchain_core/_api/deprecation.py @@ -0,0 +1,341 @@ +"""Helper functions for deprecating parts of the LangChain API. + +This module was adapted from matplotlibs _api/deprecation.py module: + +https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecation.py + +.. warning:: + + This module is for internal use only. Do not use it in your own code. + We may change the API at any time with no warning. +""" + +import contextlib +import functools +import inspect +import warnings +from typing import Any, Callable, Generator, Type, TypeVar + + +class LangChainDeprecationWarning(DeprecationWarning): + """A class for issuing deprecation warnings for LangChain users.""" + + +class LangChainPendingDeprecationWarning(PendingDeprecationWarning): + """A class for issuing deprecation warnings for LangChain users.""" + + +# PUBLIC API + + +T = TypeVar("T", Type, Callable) + + +def deprecated( + since: str, + *, + message: str = "", + name: str = "", + alternative: str = "", + pending: bool = False, + obj_type: str = "", + addendum: str = "", + removal: str = "", +) -> Callable[[T], T]: + """Decorator to mark a function, a class, or a property as deprecated. + + When deprecating a classmethod, a staticmethod, or a property, the + ``@deprecated`` decorator should go *under* ``@classmethod`` and + ``@staticmethod`` (i.e., `deprecated` should directly decorate the + underlying callable), but *over* ``@property``. + + When deprecating a class ``C`` intended to be used as a base class in a + multiple inheritance hierarchy, ``C`` *must* define an ``__init__`` method + (if ``C`` instead inherited its ``__init__`` from its own base class, then + ``@deprecated`` would mess up ``__init__`` inheritance when installing its + own (deprecation-emitting) ``C.__init__``). + + Parameters are the same as for `warn_deprecated`, except that *obj_type* + defaults to 'class' if decorating a class, 'attribute' if decorating a + property, and 'function' otherwise. + + Arguments: + since : str + The release at which this API became deprecated. + message : str, optional + Override the default deprecation message. The %(since)s, + %(name)s, %(alternative)s, %(obj_type)s, %(addendum)s, + and %(removal)s format specifiers will be replaced by the + values of the respective arguments passed to this function. + name : str, optional + The name of the deprecated object. 
+ alternative : str, optional + An alternative API that the user may use in place of the + deprecated API. The deprecation warning will tell the user + about this alternative if provided. + pending : bool, optional + If True, uses a PendingDeprecationWarning instead of a + DeprecationWarning. Cannot be used together with removal. + obj_type : str, optional + The object type being deprecated. + addendum : str, optional + Additional text appended directly to the final message. + removal : str, optional + The expected removal version. With the default (an empty + string), a removal version is automatically computed from + since. Set to other Falsy values to not schedule a removal + date. Cannot be used together with pending. + + Examples + -------- + + .. code-block:: python + + @deprecated('1.4.0') + def the_function_to_deprecate(): + pass + """ + + def deprecate( + obj: T, + *, + _obj_type: str = obj_type, + _name: str = name, + _message: str = message, + _alternative: str = alternative, + _pending: bool = pending, + _addendum: str = addendum, + ) -> T: + """Implementation of the decorator returned by `deprecated`.""" + if isinstance(obj, type): + if not _obj_type: + _obj_type = "class" + wrapped = obj.__init__ # type: ignore + _name = _name or obj.__name__ + old_doc = obj.__doc__ + + def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: + """Finalize the deprecation of a class.""" + try: + obj.__doc__ = new_doc + except AttributeError: # Can't set on some extension objects. + pass + obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc] + wrapper + ) + return obj + + elif isinstance(obj, property): + if not _obj_type: + _obj_type = "attribute" + wrapped = None + _name = _name or obj.fget.__name__ + old_doc = obj.__doc__ + + class _deprecated_property(type(obj)): # type: ignore + """A deprecated property.""" + + def __get__(self, instance, owner=None): # type: ignore + if instance is not None or owner is not None: + emit_warning() + return super().__get__(instance, owner) + + def __set__(self, instance, value): # type: ignore + if instance is not None: + emit_warning() + return super().__set__(instance, value) + + def __delete__(self, instance): # type: ignore + if instance is not None: + emit_warning() + return super().__delete__(instance) + + def __set_name__(self, owner, set_name): # type: ignore + nonlocal _name + if _name == "": + _name = set_name + + def finalize(_: Any, new_doc: str) -> Any: # type: ignore + """Finalize the property.""" + return _deprecated_property( + fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc + ) + + else: + if not _obj_type: + _obj_type = "function" + wrapped = obj + _name = _name or obj.__name__ # type: ignore + old_doc = wrapped.__doc__ + + def finalize( # type: ignore + wrapper: Callable[..., Any], new_doc: str + ) -> T: + """Wrap the wrapped function using the wrapper and update the docstring. + + Args: + wrapper: The wrapper function. + new_doc: The new docstring. + + Returns: + The wrapped function. + """ + wrapper = functools.wraps(wrapped)(wrapper) + wrapper.__doc__ = new_doc + return wrapper + + def emit_warning() -> None: + """Emit the warning.""" + warn_deprecated( + since, + message=_message, + name=_name, + alternative=_alternative, + pending=_pending, + obj_type=_obj_type, + addendum=_addendum, + removal=removal, + ) + + def warning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any: + """Wrapper for the original wrapped callable that emits a warning. + + Args: + *args: The positional arguments to the function. 
+ **kwargs: The keyword arguments to the function. + + Returns: + The return value of the function being wrapped. + """ + emit_warning() + return wrapped(*args, **kwargs) + + old_doc = inspect.cleandoc(old_doc or "").strip("\n") + + if not old_doc: + new_doc = "[*Deprecated*]" + else: + new_doc = f"[*Deprecated*] {old_doc}" + + # Modify the docstring to include a deprecation notice. + notes_header = "\nNotes\n-----" + components = [ + message, + f"Use {alternative} instead." if alternative else "", + addendum, + ] + details = " ".join([component.strip() for component in components if component]) + new_doc += ( + f"[*Deprecated*] {old_doc}\n" + f"{notes_header if notes_header not in old_doc else ''}\n" + f".. deprecated:: {since}\n" + f" {details}" + ) + + return finalize(warning_emitting_wrapper, new_doc) + + return deprecate + + +@contextlib.contextmanager +def suppress_langchain_deprecation_warning() -> Generator[None, None, None]: + """Context manager to suppress LangChainDeprecationWarning.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", LangChainDeprecationWarning) + warnings.simplefilter("ignore", LangChainPendingDeprecationWarning) + yield + + +def warn_deprecated( + since: str, + *, + message: str = "", + name: str = "", + alternative: str = "", + pending: bool = False, + obj_type: str = "", + addendum: str = "", + removal: str = "", +) -> None: + """Display a standardized deprecation. + + Arguments: + since : str + The release at which this API became deprecated. + message : str, optional + Override the default deprecation message. The %(since)s, + %(name)s, %(alternative)s, %(obj_type)s, %(addendum)s, + and %(removal)s format specifiers will be replaced by the + values of the respective arguments passed to this function. + name : str, optional + The name of the deprecated object. + alternative : str, optional + An alternative API that the user may use in place of the + deprecated API. The deprecation warning will tell the user + about this alternative if provided. + pending : bool, optional + If True, uses a PendingDeprecationWarning instead of a + DeprecationWarning. Cannot be used together with removal. + obj_type : str, optional + The object type being deprecated. + addendum : str, optional + Additional text appended directly to the final message. + removal : str, optional + The expected removal version. With the default (an empty + string), a removal version is automatically computed from + since. Set to other Falsy values to not schedule a removal + date. Cannot be used together with pending. + """ + if pending and removal: + raise ValueError("A pending deprecation cannot have a scheduled removal") + + if not pending: + if not removal: + removal = f"in {removal}" if removal else "within ?? minor releases" + raise NotImplementedError( + f"Need to determine which default deprecation schedule to use. " + f"{removal}" + ) + else: + removal = f"in {removal}" + + if not message: + message = "" + + if obj_type: + message += f"The {obj_type} `{name}`" + else: + message += f"`{name}`" + + if pending: + message += " will be deprecated in a future version" + else: + message += f" was deprecated in LangChain {since}" + + if removal: + message += f" and will be removed {removal}" + + if alternative: + message += f". Use {alternative} instead." 
+ + if addendum: + message += f" {addendum}" + + warning_cls = ( + LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning + ) + warning = warning_cls(message) + warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2) + + +def surface_langchain_deprecation_warnings() -> None: + """Unmute LangChain deprecation warnings.""" + warnings.filterwarnings( + "default", + category=LangChainPendingDeprecationWarning, + ) + + warnings.filterwarnings( + "default", + category=LangChainDeprecationWarning, + ) diff --git a/libs/core/langchain_core/_api/path.py b/libs/core/langchain_core/_api/path.py new file mode 100644 index 00000000000..0589ae44956 --- /dev/null +++ b/libs/core/langchain_core/_api/path.py @@ -0,0 +1,36 @@ +import os +from pathlib import Path +from typing import Optional, Union + +HERE = Path(__file__).parent + +# Get directory of langchain package +PACKAGE_DIR = HERE.parent +SEPARATOR = os.sep + + +def get_relative_path( + file: Union[Path, str], *, relative_to: Path = PACKAGE_DIR +) -> str: + """Get the path of the file as a relative path to the package directory.""" + if isinstance(file, str): + file = Path(file) + return str(file.relative_to(relative_to)) + + +def as_import_path( + file: Union[Path, str], + *, + suffix: Optional[str] = None, + relative_to: Path = PACKAGE_DIR, +) -> str: + """Path of the file as a LangChain import exclude langchain top namespace.""" + if isinstance(file, str): + file = Path(file) + path = get_relative_path(file, relative_to=relative_to) + if file.is_file(): + path = path[: -len(file.suffix)] + import_path = path.replace(SEPARATOR, ".") + if suffix: + import_path += "." + suffix + return import_path diff --git a/libs/langchain/tests/unit_tests/_api/__init__.py b/libs/core/langchain_core/callbacks/__init__.py similarity index 100% rename from libs/langchain/tests/unit_tests/_api/__init__.py rename to libs/core/langchain_core/callbacks/__init__.py diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py new file mode 100644 index 00000000000..030e5ecbaf3 --- /dev/null +++ b/libs/core/langchain_core/callbacks/base.py @@ -0,0 +1,598 @@ +"""Base callback handler that can be used to handle callbacks in langchain.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union +from uuid import UUID + +from tenacity import RetryCallState + +from langchain_core.schema.agent import AgentAction, AgentFinish +from langchain_core.schema.document import Document +from langchain_core.schema.messages import BaseMessage +from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult + + +class RetrieverManagerMixin: + """Mixin for Retriever callbacks.""" + + def on_retriever_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when Retriever errors.""" + + def on_retriever_end( + self, + documents: Sequence[Document], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when Retriever ends running.""" + + +class LLMManagerMixin: + """Mixin for LLM callbacks.""" + + def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on new LLM token. Only available when streaming is enabled. + + Args: + token (str): The new token. 
+ chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk, + containing content and other information. + """ + + def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when LLM ends running.""" + + def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when LLM errors.""" + + +class ChainManagerMixin: + """Mixin for chain callbacks.""" + + def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when chain ends running.""" + + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when chain errors.""" + + def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on agent action.""" + + def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on agent end.""" + + +class ToolManagerMixin: + """Mixin for tool callbacks.""" + + def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when tool ends running.""" + + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when tool errors.""" + + +class CallbackManagerMixin: + """Mixin for callback manager.""" + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when LLM starts running.""" + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when a chat model starts running.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not implement `on_chat_model_start`" + ) + + def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when Retriever starts running.""" + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when chain starts running.""" + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when tool starts running.""" + + +class RunManagerMixin: + """Mixin for run manager.""" + + def on_text( + self, + text: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on arbitrary text.""" + + def on_retry( 
+ self, + retry_state: RetryCallState, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on a retry event.""" + + +class BaseCallbackHandler( + LLMManagerMixin, + ChainManagerMixin, + ToolManagerMixin, + RetrieverManagerMixin, + CallbackManagerMixin, + RunManagerMixin, +): + """Base callback handler that handles callbacks from LangChain.""" + + raise_error: bool = False + + run_inline: bool = False + + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return False + + @property + def ignore_retry(self) -> bool: + """Whether to ignore retry callbacks.""" + return False + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return False + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return False + + @property + def ignore_retriever(self) -> bool: + """Whether to ignore retriever callbacks.""" + return False + + @property + def ignore_chat_model(self) -> bool: + """Whether to ignore chat model callbacks.""" + return False + + +class AsyncCallbackHandler(BaseCallbackHandler): + """Async callback handler that handles callbacks from LangChain.""" + + async def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM starts running.""" + + async def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when a chat model starts running.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not implement `on_chat_model_start`" + ) + + async def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on new LLM token. 
Only available when streaming is enabled.""" + + async def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM ends running.""" + + async def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM errors.""" + + async def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Run when chain starts running.""" + + async def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when chain ends running.""" + + async def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when chain errors.""" + + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Run when tool starts running.""" + + async def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool ends running.""" + + async def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool errors.""" + + async def on_text( + self, + text: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on arbitrary text.""" + + async def on_retry( + self, + retry_state: RetryCallState, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on a retry event.""" + + async def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on agent action.""" + + async def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on agent end.""" + + async def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Run on retriever start.""" + + async def on_retriever_end( + self, + documents: Sequence[Document], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on retriever end.""" + + async def on_retriever_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on retriever error.""" + + +T = TypeVar("T", 
bound="BaseCallbackManager") + + +class BaseCallbackManager(CallbackManagerMixin): + """Base callback manager that handles callbacks from LangChain.""" + + def __init__( + self, + handlers: List[BaseCallbackHandler], + inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, + parent_run_id: Optional[UUID] = None, + *, + tags: Optional[List[str]] = None, + inheritable_tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize callback manager.""" + self.handlers: List[BaseCallbackHandler] = handlers + self.inheritable_handlers: List[BaseCallbackHandler] = ( + inheritable_handlers or [] + ) + self.parent_run_id: Optional[UUID] = parent_run_id + self.tags = tags or [] + self.inheritable_tags = inheritable_tags or [] + self.metadata = metadata or {} + self.inheritable_metadata = inheritable_metadata or {} + + def copy(self: T) -> T: + """Copy the callback manager.""" + return self.__class__( + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + @property + def is_async(self) -> bool: + """Whether the callback manager is async.""" + return False + + def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: + """Add a handler to the callback manager.""" + if handler not in self.handlers: + self.handlers.append(handler) + if inherit and handler not in self.inheritable_handlers: + self.inheritable_handlers.append(handler) + + def remove_handler(self, handler: BaseCallbackHandler) -> None: + """Remove a handler from the callback manager.""" + self.handlers.remove(handler) + self.inheritable_handlers.remove(handler) + + def set_handlers( + self, handlers: List[BaseCallbackHandler], inherit: bool = True + ) -> None: + """Set handlers as the only handlers on the callback manager.""" + self.handlers = [] + self.inheritable_handlers = [] + for handler in handlers: + self.add_handler(handler, inherit=inherit) + + def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: + """Set handler as the only handler on the callback manager.""" + self.set_handlers([handler], inherit=inherit) + + def add_tags(self, tags: List[str], inherit: bool = True) -> None: + for tag in tags: + if tag in self.tags: + self.remove_tags([tag]) + self.tags.extend(tags) + if inherit: + self.inheritable_tags.extend(tags) + + def remove_tags(self, tags: List[str]) -> None: + for tag in tags: + self.tags.remove(tag) + self.inheritable_tags.remove(tag) + + def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None: + self.metadata.update(metadata) + if inherit: + self.inheritable_metadata.update(metadata) + + def remove_metadata(self, keys: List[str]) -> None: + for key in keys: + self.metadata.pop(key) + self.inheritable_metadata.pop(key) + + +Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py new file mode 100644 index 00000000000..efd4d5550af --- /dev/null +++ b/libs/core/langchain_core/callbacks/manager.py @@ -0,0 +1,2075 @@ +from __future__ import annotations + +import asyncio +import functools +import logging +import os +import uuid +from concurrent.futures import ThreadPoolExecutor +from contextlib import 
asynccontextmanager, contextmanager +from contextvars import ContextVar +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Coroutine, + Dict, + Generator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, +) +from uuid import UUID + +from langsmith import utils as ls_utils +from langsmith.run_helpers import get_run_tree_context +from tenacity import RetryCallState + +from langchain_core.callbacks.base import ( + BaseCallbackHandler, + BaseCallbackManager, + Callbacks, + ChainManagerMixin, + LLMManagerMixin, + RetrieverManagerMixin, + RunManagerMixin, + ToolManagerMixin, +) +from langchain_core.callbacks.stdout import StdOutCallbackHandler +from langchain_core.callbacks.tracers import run_collector +from langchain_core.callbacks.tracers.langchain import ( + LangChainTracer, +) +from langchain_core.callbacks.tracers.langchain_v1 import ( + LangChainTracerV1, + TracerSessionV1, +) +from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler +from langchain_core.schema import ( + AgentAction, + AgentFinish, + Document, + LLMResult, +) +from langchain_core.schema.messages import BaseMessage, get_buffer_string +from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk + +if TYPE_CHECKING: + from langsmith import Client as LangSmithClient + +logger = logging.getLogger(__name__) + +tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501 + "tracing_callback", default=None +) + +tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501 + "tracing_callback_v2", default=None +) +run_collector_var: ContextVar[ + Optional[run_collector.RunCollectorCallbackHandler] +] = ContextVar( # noqa: E501 + "run_collector", default=None +) + + +def _get_debug() -> bool: + from langchain_core.globals import get_debug + + return get_debug() + + +@contextmanager +def tracing_enabled( + session_name: str = "default", +) -> Generator[TracerSessionV1, None, None]: + """Get the Deprecated LangChainTracer in a context manager. + + Args: + session_name (str, optional): The name of the session. + Defaults to "default". + + Returns: + TracerSessionV1: The LangChainTracer session. + + Example: + >>> with tracing_enabled() as session: + ... # Use the LangChainTracer session + """ + cb = LangChainTracerV1() + session = cast(TracerSessionV1, cb.load_session(session_name)) + try: + tracing_callback_var.set(cb) + yield session + finally: + tracing_callback_var.set(None) + + +@contextmanager +def tracing_v2_enabled( + project_name: Optional[str] = None, + *, + example_id: Optional[Union[str, UUID]] = None, + tags: Optional[List[str]] = None, + client: Optional[LangSmithClient] = None, +) -> Generator[LangChainTracer, None, None]: + """Instruct LangChain to log all runs in context to LangSmith. + + Args: + project_name (str, optional): The name of the project. + Defaults to "default". + example_id (str or UUID, optional): The ID of the example. + Defaults to None. + tags (List[str], optional): The tags to add to the run. + Defaults to None. + + Returns: + None + + Example: + >>> with tracing_v2_enabled(): + ... # LangChain code will automatically be traced + + You can use this to fetch the LangSmith run URL: + + >>> with tracing_v2_enabled() as cb: + ... chain.invoke("foo") + ... 
run_url = cb.get_run_url() + """ + if isinstance(example_id, str): + example_id = UUID(example_id) + cb = LangChainTracer( + example_id=example_id, + project_name=project_name, + tags=tags, + client=client, + ) + try: + tracing_v2_callback_var.set(cb) + yield cb + finally: + tracing_v2_callback_var.set(None) + + +@contextmanager +def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]: + """Collect all run traces in context. + + Returns: + run_collector.RunCollectorCallbackHandler: The run collector callback handler. + + Example: + >>> with collect_runs() as runs_cb: + chain.invoke("foo") + run_id = runs_cb.traced_runs[0].id + """ + cb = run_collector.RunCollectorCallbackHandler() + run_collector_var.set(cb) + yield cb + run_collector_var.set(None) + + +def _get_trace_callbacks( + project_name: Optional[str] = None, + example_id: Optional[Union[str, UUID]] = None, + callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None, +) -> Callbacks: + if _tracing_v2_is_enabled(): + project_name_ = project_name or _get_tracer_project() + tracer = tracing_v2_callback_var.get() or LangChainTracer( + project_name=project_name_, + example_id=example_id, + ) + if callback_manager is None: + cb = cast(Callbacks, [tracer]) + else: + if not any( + isinstance(handler, LangChainTracer) + for handler in callback_manager.handlers + ): + callback_manager.add_handler(tracer, True) + # If it already has a LangChainTracer, we don't need to add another one. + # this would likely mess up the trace hierarchy. + cb = callback_manager + else: + cb = None + return cb + + +@contextmanager +def trace_as_chain_group( + group_name: str, + callback_manager: Optional[CallbackManager] = None, + *, + inputs: Optional[Dict[str, Any]] = None, + project_name: Optional[str] = None, + example_id: Optional[Union[str, UUID]] = None, + run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, +) -> Generator[CallbackManagerForChainGroup, None, None]: + """Get a callback manager for a chain group in a context manager. + Useful for grouping different calls together as a single run even if + they aren't composed in a single chain. + + Args: + group_name (str): The name of the chain group. + callback_manager (CallbackManager, optional): The callback manager to use. + inputs (Dict[str, Any], optional): The inputs to the chain group. + project_name (str, optional): The name of the project. + Defaults to None. + example_id (str or UUID, optional): The ID of the example. + Defaults to None. + run_id (UUID, optional): The ID of the run. + tags (List[str], optional): The inheritable tags to apply to all runs. + Defaults to None. + + Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. + + Returns: + CallbackManagerForChainGroup: The callback manager for the chain group. + + Example: + .. 
code-block:: python + + llm_input = "Foo" + with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: + # Use the callback manager for the chain group + res = llm.predict(llm_input, callbacks=manager) + manager.on_chain_end({"output": res}) + """ # noqa: E501 + cb = _get_trace_callbacks( + project_name, example_id, callback_manager=callback_manager + ) + cm = CallbackManager.configure( + inheritable_callbacks=cb, + inheritable_tags=tags, + ) + + run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id) + child_cm = run_manager.get_child() + group_cm = CallbackManagerForChainGroup( + child_cm.handlers, + child_cm.inheritable_handlers, + child_cm.parent_run_id, + parent_run_manager=run_manager, + tags=child_cm.tags, + inheritable_tags=child_cm.inheritable_tags, + metadata=child_cm.metadata, + inheritable_metadata=child_cm.inheritable_metadata, + ) + try: + yield group_cm + except Exception as e: + if not group_cm.ended: + run_manager.on_chain_error(e) + raise e + else: + if not group_cm.ended: + run_manager.on_chain_end({}) + + +@asynccontextmanager +async def atrace_as_chain_group( + group_name: str, + callback_manager: Optional[AsyncCallbackManager] = None, + *, + inputs: Optional[Dict[str, Any]] = None, + project_name: Optional[str] = None, + example_id: Optional[Union[str, UUID]] = None, + run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, +) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]: + """Get an async callback manager for a chain group in a context manager. + Useful for grouping different async calls together as a single run even if + they aren't composed in a single chain. + + Args: + group_name (str): The name of the chain group. + callback_manager (AsyncCallbackManager, optional): The async callback manager to use, + which manages tracing and other callback behavior. + project_name (str, optional): The name of the project. + Defaults to None. + example_id (str or UUID, optional): The ID of the example. + Defaults to None. + run_id (UUID, optional): The ID of the run. + tags (List[str], optional): The inheritable tags to apply to all runs. + Defaults to None. + Returns: + AsyncCallbackManager: The async callback manager for the chain group. + + Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. + + Example: + .. 
code-block:: python + + llm_input = "Foo" + async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: + # Use the async callback manager for the chain group + res = await llm.apredict(llm_input, callbacks=manager) + await manager.on_chain_end({"output": res}) + """ # noqa: E501 + cb = _get_trace_callbacks( + project_name, example_id, callback_manager=callback_manager + ) + cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags) + + run_manager = await cm.on_chain_start( + {"name": group_name}, inputs or {}, run_id=run_id + ) + child_cm = run_manager.get_child() + group_cm = AsyncCallbackManagerForChainGroup( + child_cm.handlers, + child_cm.inheritable_handlers, + child_cm.parent_run_id, + parent_run_manager=run_manager, + tags=child_cm.tags, + inheritable_tags=child_cm.inheritable_tags, + metadata=child_cm.metadata, + inheritable_metadata=child_cm.inheritable_metadata, + ) + try: + yield group_cm + except Exception as e: + if not group_cm.ended: + await run_manager.on_chain_error(e) + raise e + else: + if not group_cm.ended: + await run_manager.on_chain_end({}) + + +def handle_event( + handlers: List[BaseCallbackHandler], + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + """Generic event handler for CallbackManager. + + Note: This function is used by langserve to handle events. + + Args: + handlers: The list of handlers that will handle the event + event_name: The name of the event (e.g., "on_llm_start") + ignore_condition_name: Name of the attribute defined on handler + that if True will cause the handler to be skipped for the given event + *args: The arguments to pass to the event handler + **kwargs: The keyword arguments to pass to the event handler + """ + coros: List[Coroutine[Any, Any, Any]] = [] + + try: + message_strings: Optional[List[str]] = None + for handler in handlers: + try: + if ignore_condition_name is None or not getattr( + handler, ignore_condition_name + ): + event = getattr(handler, event_name)(*args, **kwargs) + if asyncio.iscoroutine(event): + coros.append(event) + except NotImplementedError as e: + if event_name == "on_chat_model_start": + if message_strings is None: + message_strings = [get_buffer_string(m) for m in args[1]] + handle_event( + [handler], + "on_llm_start", + "ignore_llm", + args[0], + message_strings, + *args[2:], + **kwargs, + ) + else: + handler_name = handler.__class__.__name__ + logger.warning( + f"NotImplementedError in {handler_name}.{event_name}" + f" callback: {repr(e)}" + ) + except Exception as e: + logger.warning( + f"Error in {handler.__class__.__name__}.{event_name} callback:" + f" {repr(e)}" + ) + if handler.raise_error: + raise e + finally: + if coros: + try: + # Raises RuntimeError if there is no current event loop. + asyncio.get_running_loop() + loop_running = True + except RuntimeError: + loop_running = False + + if loop_running: + # If we try to submit this coroutine to the running loop + # we end up in a deadlock, as we'd have gotten here from a + # running coroutine, which we cannot interrupt to run this one. + # The solution is to create a new loop in a new thread. 
+ with ThreadPoolExecutor(1) as executor: + executor.submit(_run_coros, coros).result() + else: + _run_coros(coros) + + +def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None: + if hasattr(asyncio, "Runner"): + # Python 3.11+ + # Run the coroutines in a new event loop, taking care to + # - install signal handlers + # - run pending tasks scheduled by `coros` + # - close asyncgens and executors + # - close the loop + with asyncio.Runner() as runner: + # Run the coroutine, get the result + for coro in coros: + runner.run(coro) + + # Run pending tasks scheduled by coros until they are all done + while pending := asyncio.all_tasks(runner.get_loop()): + runner.run(asyncio.wait(pending)) + else: + # Before Python 3.11 we need to run each coroutine in a new event loop + # as the Runner api is not available. + for coro in coros: + asyncio.run(coro) + + +async def _ahandle_event_for_handler( + handler: BaseCallbackHandler, + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + try: + if ignore_condition_name is None or not getattr(handler, ignore_condition_name): + event = getattr(handler, event_name) + if asyncio.iscoroutinefunction(event): + await event(*args, **kwargs) + else: + if handler.run_inline: + event(*args, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, functools.partial(event, *args, **kwargs) + ) + except NotImplementedError as e: + if event_name == "on_chat_model_start": + message_strings = [get_buffer_string(m) for m in args[1]] + await _ahandle_event_for_handler( + handler, + "on_llm_start", + "ignore_llm", + args[0], + message_strings, + *args[2:], + **kwargs, + ) + else: + logger.warning( + f"NotImplementedError in {handler.__class__.__name__}.{event_name}" + f" callback: {repr(e)}" + ) + except Exception as e: + logger.warning( + f"Error in {handler.__class__.__name__}.{event_name} callback:" + f" {repr(e)}" + ) + if handler.raise_error: + raise e + + +async def ahandle_event( + handlers: List[BaseCallbackHandler], + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + """Generic event handler for AsyncCallbackManager. + + Note: This function is used by langserve to handle events. 
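+
+ Handlers with run_inline=True are awaited one at a time first; the
+ remaining handlers are then dispatched concurrently with asyncio.gather.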
+ + Args: + handlers: The list of handlers that will handle the event + event_name: The name of the event (e.g., "on_llm_start") + ignore_condition_name: Name of the attribute defined on handler + that if True will cause the handler to be skipped for the given event + *args: The arguments to pass to the event handler + **kwargs: The keyword arguments to pass to the event handler + """ + for handler in [h for h in handlers if h.run_inline]: + await _ahandle_event_for_handler( + handler, event_name, ignore_condition_name, *args, **kwargs + ) + await asyncio.gather( + *( + _ahandle_event_for_handler( + handler, event_name, ignore_condition_name, *args, **kwargs + ) + for handler in handlers + if not handler.run_inline + ) + ) + + +BRM = TypeVar("BRM", bound="BaseRunManager") + + +class BaseRunManager(RunManagerMixin): + """Base class for run manager (a bound callback manager).""" + + def __init__( + self, + *, + run_id: UUID, + handlers: List[BaseCallbackHandler], + inheritable_handlers: List[BaseCallbackHandler], + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + inheritable_tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize the run manager. + + Args: + run_id (UUID): The ID of the run. + handlers (List[BaseCallbackHandler]): The list of handlers. + inheritable_handlers (List[BaseCallbackHandler]): + The list of inheritable handlers. + parent_run_id (UUID, optional): The ID of the parent run. + Defaults to None. + tags (Optional[List[str]]): The list of tags. + inheritable_tags (Optional[List[str]]): The list of inheritable tags. + metadata (Optional[Dict[str, Any]]): The metadata. + inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata. + """ + self.run_id = run_id + self.handlers = handlers + self.inheritable_handlers = inheritable_handlers + self.parent_run_id = parent_run_id + self.tags = tags or [] + self.inheritable_tags = inheritable_tags or [] + self.metadata = metadata or {} + self.inheritable_metadata = inheritable_metadata or {} + + @classmethod + def get_noop_manager(cls: Type[BRM]) -> BRM: + """Return a manager that doesn't perform any operations. + + Returns: + BaseRunManager: The noop manager. + """ + return cls( + run_id=uuid.uuid4(), + handlers=[], + inheritable_handlers=[], + tags=[], + inheritable_tags=[], + metadata={}, + inheritable_metadata={}, + ) + + +class RunManager(BaseRunManager): + """Sync Run Manager.""" + + def on_text( + self, + text: str, + **kwargs: Any, + ) -> Any: + """Run when text is received. + + Args: + text (str): The received text. + + Returns: + Any: The result of the callback. + """ + handle_event( + self.handlers, + "on_text", + None, + text, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_retry( + self, + retry_state: RetryCallState, + **kwargs: Any, + ) -> None: + handle_event( + self.handlers, + "on_retry", + "ignore_retry", + retry_state, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class ParentRunManager(RunManager): + """Sync Parent Run Manager.""" + + def get_child(self, tag: Optional[str] = None) -> CallbackManager: + """Get a child callback manager. + + Args: + tag (str, optional): The tag for the child callback manager. + Defaults to None. + + Returns: + CallbackManager: The child callback manager. 
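+
+ The returned manager carries over this run's inheritable handlers,
+ tags and metadata, and uses this run's ID as its parent_run_id.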
+ """ + manager = CallbackManager(handlers=[], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + manager.add_tags(self.inheritable_tags) + manager.add_metadata(self.inheritable_metadata) + if tag is not None: + manager.add_tags([tag], False) + return manager + + +class AsyncRunManager(BaseRunManager): + """Async Run Manager.""" + + async def on_text( + self, + text: str, + **kwargs: Any, + ) -> Any: + """Run when text is received. + + Args: + text (str): The received text. + + Returns: + Any: The result of the callback. + """ + await ahandle_event( + self.handlers, + "on_text", + None, + text, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_retry( + self, + retry_state: RetryCallState, + **kwargs: Any, + ) -> None: + await ahandle_event( + self.handlers, + "on_retry", + "ignore_retry", + retry_state, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class AsyncParentRunManager(AsyncRunManager): + """Async Parent Run Manager.""" + + def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: + """Get a child callback manager. + + Args: + tag (str, optional): The tag for the child callback manager. + Defaults to None. + + Returns: + AsyncCallbackManager: The child callback manager. + """ + manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + manager.add_tags(self.inheritable_tags) + manager.add_metadata(self.inheritable_metadata) + if tag is not None: + manager.add_tags([tag], False) + return manager + + +class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): + """Callback manager for LLM run.""" + + def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM generates a new token. + + Args: + token (str): The new token. + """ + handle_event( + self.handlers, + "on_llm_new_token", + "ignore_llm", + token=token, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + chunk=chunk, + **kwargs, + ) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running. + + Args: + response (LLMResult): The LLM result. + """ + handle_event( + self.handlers, + "on_llm_end", + "ignore_llm", + response, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_llm_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when LLM errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + handle_event( + self.handlers, + "on_llm_error", + "ignore_llm", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): + """Async callback manager for LLM run.""" + + async def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM generates a new token. + + Args: + token (str): The new token. + """ + await ahandle_event( + self.handlers, + "on_llm_new_token", + "ignore_llm", + token, + chunk=chunk, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running. 
+ + Args: + response (LLMResult): The LLM result. + """ + await ahandle_event( + self.handlers, + "on_llm_end", + "ignore_llm", + response, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_llm_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when LLM errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + await ahandle_event( + self.handlers, + "on_llm_error", + "ignore_llm", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): + """Callback manager for chain run.""" + + def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: + """Run when chain ends running. + + Args: + outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + """ + handle_event( + self.handlers, + "on_chain_end", + "ignore_chain", + outputs, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_chain_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when chain errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + handle_event( + self.handlers, + "on_chain_error", + "ignore_chain", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run when agent action is received. + + Args: + action (AgentAction): The agent action. + + Returns: + Any: The result of the callback. + """ + handle_event( + self.handlers, + "on_agent_action", + "ignore_agent", + action, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + """Run when agent finish is received. + + Args: + finish (AgentFinish): The agent finish. + + Returns: + Any: The result of the callback. + """ + handle_event( + self.handlers, + "on_agent_finish", + "ignore_agent", + finish, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): + """Async callback manager for chain run.""" + + async def on_chain_end( + self, outputs: Union[Dict[str, Any], Any], **kwargs: Any + ) -> None: + """Run when chain ends running. + + Args: + outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + """ + await ahandle_event( + self.handlers, + "on_chain_end", + "ignore_chain", + outputs, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_chain_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when chain errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + await ahandle_event( + self.handlers, + "on_chain_error", + "ignore_chain", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run when agent action is received. + + Args: + action (AgentAction): The agent action. + + Returns: + Any: The result of the callback. 
+ """ + await ahandle_event( + self.handlers, + "on_agent_action", + "ignore_agent", + action, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + """Run when agent finish is received. + + Args: + finish (AgentFinish): The agent finish. + + Returns: + Any: The result of the callback. + """ + await ahandle_event( + self.handlers, + "on_agent_finish", + "ignore_agent", + finish, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): + """Callback manager for tool run.""" + + def on_tool_end( + self, + output: str, + **kwargs: Any, + ) -> None: + """Run when tool ends running. + + Args: + output (str): The output of the tool. + """ + handle_event( + self.handlers, + "on_tool_end", + "ignore_agent", + output, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_tool_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when tool errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + handle_event( + self.handlers, + "on_tool_error", + "ignore_agent", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): + """Async callback manager for tool run.""" + + async def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running. + + Args: + output (str): The output of the tool. + """ + await ahandle_event( + self.handlers, + "on_tool_end", + "ignore_agent", + output, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_tool_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when tool errors. + + Args: + error (Exception or KeyboardInterrupt): The error. 
+ """ + await ahandle_event( + self.handlers, + "on_tool_error", + "ignore_agent", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): + """Callback manager for retriever run.""" + + def on_retriever_end( + self, + documents: Sequence[Document], + **kwargs: Any, + ) -> None: + """Run when retriever ends running.""" + handle_event( + self.handlers, + "on_retriever_end", + "ignore_retriever", + documents, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + def on_retriever_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when retriever errors.""" + handle_event( + self.handlers, + "on_retriever_error", + "ignore_retriever", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class AsyncCallbackManagerForRetrieverRun( + AsyncParentRunManager, + RetrieverManagerMixin, +): + """Async callback manager for retriever run.""" + + async def on_retriever_end( + self, documents: Sequence[Document], **kwargs: Any + ) -> None: + """Run when retriever ends running.""" + await ahandle_event( + self.handlers, + "on_retriever_end", + "ignore_retriever", + documents, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + async def on_retriever_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when retriever errors.""" + await ahandle_event( + self.handlers, + "on_retriever_error", + "ignore_retriever", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + +class CallbackManager(BaseCallbackManager): + """Callback manager that handles callbacks from LangChain.""" + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + **kwargs: Any, + ) -> List[CallbackManagerForLLMRun]: + """Run when LLM starts running. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + prompts (List[str]): The list of prompts. + run_id (UUID, optional): The ID of the run. Defaults to None. + + Returns: + List[CallbackManagerForLLMRun]: A callback manager for each + prompt as an LLM run. + """ + managers = [] + for prompt in prompts: + run_id_ = uuid.uuid4() + handle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + [prompt], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + managers.append( + CallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + ) + + return managers + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + **kwargs: Any, + ) -> List[CallbackManagerForLLMRun]: + """Run when LLM starts running. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + messages (List[List[BaseMessage]]): The list of messages. + run_id (UUID, optional): The ID of the run. Defaults to None. + + Returns: + List[CallbackManagerForLLMRun]: A callback manager for each + list of messages as an LLM run. 
+ """ + + managers = [] + for message_list in messages: + run_id_ = uuid.uuid4() + handle_event( + self.handlers, + "on_chat_model_start", + "ignore_chat_model", + serialized, + [message_list], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + managers.append( + CallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + ) + + return managers + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Union[Dict[str, Any], Any], + run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> CallbackManagerForChainRun: + """Run when chain starts running. + + Args: + serialized (Dict[str, Any]): The serialized chain. + inputs (Union[Dict[str, Any], Any]): The inputs to the chain. + run_id (UUID, optional): The ID of the run. Defaults to None. + + Returns: + CallbackManagerForChainRun: The callback manager for the chain run. + """ + if run_id is None: + run_id = uuid.uuid4() + handle_event( + self.handlers, + "on_chain_start", + "ignore_chain", + serialized, + inputs, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return CallbackManagerForChainRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> CallbackManagerForToolRun: + """Run when tool starts running. + + Args: + serialized (Dict[str, Any]): The serialized tool. + input_str (str): The input to the tool. + run_id (UUID, optional): The ID of the run. Defaults to None. + parent_run_id (UUID, optional): The ID of the parent run. Defaults to None. + + Returns: + CallbackManagerForToolRun: The callback manager for the tool run. 
+ """ + if run_id is None: + run_id = uuid.uuid4() + + handle_event( + self.handlers, + "on_tool_start", + "ignore_agent", + serialized, + input_str, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return CallbackManagerForToolRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> CallbackManagerForRetrieverRun: + """Run when retriever starts running.""" + if run_id is None: + run_id = uuid.uuid4() + + handle_event( + self.handlers, + "on_retriever_start", + "ignore_retriever", + serialized, + query, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return CallbackManagerForRetrieverRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + @classmethod + def configure( + cls, + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, + verbose: bool = False, + inheritable_tags: Optional[List[str]] = None, + local_tags: Optional[List[str]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + local_metadata: Optional[Dict[str, Any]] = None, + ) -> CallbackManager: + """Configure the callback manager. + + Args: + inheritable_callbacks (Optional[Callbacks], optional): The inheritable + callbacks. Defaults to None. + local_callbacks (Optional[Callbacks], optional): The local callbacks. + Defaults to None. + verbose (bool, optional): Whether to enable verbose mode. Defaults to False. + inheritable_tags (Optional[List[str]], optional): The inheritable tags. + Defaults to None. + local_tags (Optional[List[str]], optional): The local tags. + Defaults to None. + inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable + metadata. Defaults to None. + local_metadata (Optional[Dict[str, Any]], optional): The local metadata. + Defaults to None. + + Returns: + CallbackManager: The configured callback manager. 
+ """ + return _configure( + cls, + inheritable_callbacks, + local_callbacks, + verbose, + inheritable_tags, + local_tags, + inheritable_metadata, + local_metadata, + ) + + +class CallbackManagerForChainGroup(CallbackManager): + """Callback manager for the chain group.""" + + def __init__( + self, + handlers: List[BaseCallbackHandler], + inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, + parent_run_id: Optional[UUID] = None, + *, + parent_run_manager: CallbackManagerForChainRun, + **kwargs: Any, + ) -> None: + super().__init__( + handlers, + inheritable_handlers, + parent_run_id, + **kwargs, + ) + self.parent_run_manager = parent_run_manager + self.ended = False + + def copy(self) -> CallbackManagerForChainGroup: + return self.__class__( + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + parent_run_manager=self.parent_run_manager, + ) + + def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: + """Run when traced chain group ends. + + Args: + outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + """ + self.ended = True + return self.parent_run_manager.on_chain_end(outputs, **kwargs) + + def on_chain_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when chain errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + self.ended = True + return self.parent_run_manager.on_chain_error(error, **kwargs) + + +class AsyncCallbackManager(BaseCallbackManager): + """Async callback manager that handles callbacks from LangChain.""" + + @property + def is_async(self) -> bool: + """Return whether the handler is async.""" + return True + + async def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + **kwargs: Any, + ) -> List[AsyncCallbackManagerForLLMRun]: + """Run when LLM starts running. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + prompts (List[str]): The list of prompts. + run_id (UUID, optional): The ID of the run. Defaults to None. + + Returns: + List[AsyncCallbackManagerForLLMRun]: The list of async + callback managers, one for each LLM Run corresponding + to each prompt. + """ + + tasks = [] + managers = [] + + for prompt in prompts: + run_id_ = uuid.uuid4() + + tasks.append( + ahandle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + [prompt], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + ) + + managers.append( + AsyncCallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + ) + + await asyncio.gather(*tasks) + + return managers + + async def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + **kwargs: Any, + ) -> List[AsyncCallbackManagerForLLMRun]: + """Run when LLM starts running. + + Args: + serialized (Dict[str, Any]): The serialized LLM. + messages (List[List[BaseMessage]]): The list of messages. + run_id (UUID, optional): The ID of the run. Defaults to None. 
+ + Returns: + List[AsyncCallbackManagerForLLMRun]: The list of + async callback managers, one for each LLM Run + corresponding to each inner message list. + """ + tasks = [] + managers = [] + + for message_list in messages: + run_id_ = uuid.uuid4() + + tasks.append( + ahandle_event( + self.handlers, + "on_chat_model_start", + "ignore_chat_model", + serialized, + [message_list], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + ) + + managers.append( + AsyncCallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + ) + + await asyncio.gather(*tasks) + return managers + + async def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Union[Dict[str, Any], Any], + run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForChainRun: + """Run when chain starts running. + + Args: + serialized (Dict[str, Any]): The serialized chain. + inputs (Union[Dict[str, Any], Any]): The inputs to the chain. + run_id (UUID, optional): The ID of the run. Defaults to None. + + Returns: + AsyncCallbackManagerForChainRun: The async callback manager + for the chain run. + """ + if run_id is None: + run_id = uuid.uuid4() + + await ahandle_event( + self.handlers, + "on_chain_start", + "ignore_chain", + serialized, + inputs, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return AsyncCallbackManagerForChainRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForToolRun: + """Run when tool starts running. + + Args: + serialized (Dict[str, Any]): The serialized tool. + input_str (str): The input to the tool. + run_id (UUID, optional): The ID of the run. Defaults to None. + parent_run_id (UUID, optional): The ID of the parent run. + Defaults to None. + + Returns: + AsyncCallbackManagerForToolRun: The async callback manager + for the tool run. 
+ """ + if run_id is None: + run_id = uuid.uuid4() + + await ahandle_event( + self.handlers, + "on_tool_start", + "ignore_agent", + serialized, + input_str, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return AsyncCallbackManagerForToolRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + async def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForRetrieverRun: + """Run when retriever starts running.""" + if run_id is None: + run_id = uuid.uuid4() + + await ahandle_event( + self.handlers, + "on_retriever_start", + "ignore_retriever", + serialized, + query, + run_id=run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + metadata=self.metadata, + **kwargs, + ) + + return AsyncCallbackManagerForRetrieverRun( + run_id=run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + + @classmethod + def configure( + cls, + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, + verbose: bool = False, + inheritable_tags: Optional[List[str]] = None, + local_tags: Optional[List[str]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + local_metadata: Optional[Dict[str, Any]] = None, + ) -> AsyncCallbackManager: + """Configure the async callback manager. + + Args: + inheritable_callbacks (Optional[Callbacks], optional): The inheritable + callbacks. Defaults to None. + local_callbacks (Optional[Callbacks], optional): The local callbacks. + Defaults to None. + verbose (bool, optional): Whether to enable verbose mode. Defaults to False. + inheritable_tags (Optional[List[str]], optional): The inheritable tags. + Defaults to None. + local_tags (Optional[List[str]], optional): The local tags. + Defaults to None. + inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable + metadata. Defaults to None. + local_metadata (Optional[Dict[str, Any]], optional): The local metadata. + Defaults to None. + + Returns: + AsyncCallbackManager: The configured async callback manager. 
+ """ + return _configure( + cls, + inheritable_callbacks, + local_callbacks, + verbose, + inheritable_tags, + local_tags, + inheritable_metadata, + local_metadata, + ) + + +class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): + """Async callback manager for the chain group.""" + + def __init__( + self, + handlers: List[BaseCallbackHandler], + inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, + parent_run_id: Optional[UUID] = None, + *, + parent_run_manager: AsyncCallbackManagerForChainRun, + **kwargs: Any, + ) -> None: + super().__init__( + handlers, + inheritable_handlers, + parent_run_id, + **kwargs, + ) + self.parent_run_manager = parent_run_manager + self.ended = False + + def copy(self) -> AsyncCallbackManagerForChainGroup: + return self.__class__( + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + parent_run_manager=self.parent_run_manager, + ) + + async def on_chain_end( + self, outputs: Union[Dict[str, Any], Any], **kwargs: Any + ) -> None: + """Run when traced chain group ends. + + Args: + outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + """ + self.ended = True + await self.parent_run_manager.on_chain_end(outputs, **kwargs) + + async def on_chain_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when chain errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + self.ended = True + await self.parent_run_manager.on_chain_error(error, **kwargs) + + +T = TypeVar("T", CallbackManager, AsyncCallbackManager) + + +def env_var_is_set(env_var: str) -> bool: + """Check if an environment variable is set. + + Args: + env_var (str): The name of the environment variable. + + Returns: + bool: True if the environment variable is set, False otherwise. + """ + return env_var in os.environ and os.environ[env_var] not in ( + "", + "0", + "false", + "False", + ) + + +def _tracing_v2_is_enabled() -> bool: + return ( + env_var_is_set("LANGCHAIN_TRACING_V2") + or tracing_v2_callback_var.get() is not None + or get_run_tree_context() is not None + ) + + +def _get_tracer_project() -> str: + run_tree = get_run_tree_context() + return getattr( + run_tree, + "session_name", + getattr( + # Note, if people are trying to nest @traceable functions and the + # tracing_v2_enabled context manager, this will likely mess up the + # tree structure. + tracing_v2_callback_var.get(), + "project", + # Have to set this to a string even though it always will return + # a string because `get_tracer_project` technically can return + # None, but only when a specific argument is supplied. + # Therefore, this just tricks the mypy type checker + str(ls_utils.get_tracer_project()), + ), + ) + + +_configure_hooks: List[ + Tuple[ + ContextVar[Optional[BaseCallbackHandler]], + bool, + Optional[Type[BaseCallbackHandler]], + Optional[str], + ] +] = [] + +H = TypeVar("H", bound=BaseCallbackHandler, covariant=True) + + +def register_configure_hook( + context_var: ContextVar[Optional[Any]], + inheritable: bool, + handle_class: Optional[Type[BaseCallbackHandler]] = None, + env_var: Optional[str] = None, +) -> None: + if env_var is not None and handle_class is None: + raise ValueError( + "If env_var is set, handle_class must also be set to a non-None value." 
+ ) + _configure_hooks.append( + ( + # the typings of ContextVar do not have the generic arg set as covariant + # so we have to cast it + cast(ContextVar[Optional[BaseCallbackHandler]], context_var), + inheritable, + handle_class, + env_var, + ) + ) + + +register_configure_hook(run_collector_var, False) + + +def _configure( + callback_manager_cls: Type[T], + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, + verbose: bool = False, + inheritable_tags: Optional[List[str]] = None, + local_tags: Optional[List[str]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + local_metadata: Optional[Dict[str, Any]] = None, +) -> T: + """Configure the callback manager. + + Args: + callback_manager_cls (Type[T]): The callback manager class. + inheritable_callbacks (Optional[Callbacks], optional): The inheritable + callbacks. Defaults to None. + local_callbacks (Optional[Callbacks], optional): The local callbacks. + Defaults to None. + verbose (bool, optional): Whether to enable verbose mode. Defaults to False. + inheritable_tags (Optional[List[str]], optional): The inheritable tags. + Defaults to None. + local_tags (Optional[List[str]], optional): The local tags. Defaults to None. + inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable + metadata. Defaults to None. + local_metadata (Optional[Dict[str, Any]], optional): The local metadata. + Defaults to None. + + Returns: + T: The configured callback manager. + """ + run_tree = get_run_tree_context() + parent_run_id = None if run_tree is None else getattr(run_tree, "id") + callback_manager = callback_manager_cls(handlers=[], parent_run_id=parent_run_id) + if inheritable_callbacks or local_callbacks: + if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: + inheritable_callbacks_ = inheritable_callbacks or [] + callback_manager = callback_manager_cls( + handlers=inheritable_callbacks_.copy(), + inheritable_handlers=inheritable_callbacks_.copy(), + parent_run_id=parent_run_id, + ) + else: + callback_manager = callback_manager_cls( + handlers=inheritable_callbacks.handlers.copy(), + inheritable_handlers=inheritable_callbacks.inheritable_handlers.copy(), + parent_run_id=inheritable_callbacks.parent_run_id, + tags=inheritable_callbacks.tags.copy(), + inheritable_tags=inheritable_callbacks.inheritable_tags.copy(), + metadata=inheritable_callbacks.metadata.copy(), + inheritable_metadata=inheritable_callbacks.inheritable_metadata.copy(), + ) + local_handlers_ = ( + local_callbacks + if isinstance(local_callbacks, list) + else (local_callbacks.handlers if local_callbacks else []) + ) + for handler in local_handlers_: + callback_manager.add_handler(handler, False) + if inheritable_tags or local_tags: + callback_manager.add_tags(inheritable_tags or []) + callback_manager.add_tags(local_tags or [], False) + if inheritable_metadata or local_metadata: + callback_manager.add_metadata(inheritable_metadata or {}) + callback_manager.add_metadata(local_metadata or {}, False) + + tracer = tracing_callback_var.get() + tracing_enabled_ = ( + env_var_is_set("LANGCHAIN_TRACING") + or tracer is not None + or env_var_is_set("LANGCHAIN_HANDLER") + ) + + tracer_v2 = tracing_v2_callback_var.get() + tracing_v2_enabled_ = _tracing_v2_is_enabled() + tracer_project = _get_tracer_project() + debug = _get_debug() + if verbose or debug or tracing_enabled_ or tracing_v2_enabled_: + if verbose and not any( + isinstance(handler, StdOutCallbackHandler) + for handler in callback_manager.handlers + ): + if 
debug: + pass + else: + callback_manager.add_handler(StdOutCallbackHandler(), False) + if debug and not any( + isinstance(handler, ConsoleCallbackHandler) + for handler in callback_manager.handlers + ): + callback_manager.add_handler(ConsoleCallbackHandler(), True) + if tracing_enabled_ and not any( + isinstance(handler, LangChainTracerV1) + for handler in callback_manager.handlers + ): + if tracer: + callback_manager.add_handler(tracer, True) + else: + handler = LangChainTracerV1() + handler.load_session(tracer_project) + callback_manager.add_handler(handler, True) + if tracing_v2_enabled_ and not any( + isinstance(handler, LangChainTracer) + for handler in callback_manager.handlers + ): + if tracer_v2: + callback_manager.add_handler(tracer_v2, True) + else: + try: + handler = LangChainTracer(project_name=tracer_project) + callback_manager.add_handler(handler, True) + except Exception as e: + logger.warning( + "Unable to load requested LangChainTracer." + " To disable this warning," + " unset the LANGCHAIN_TRACING_V2 environment variables.", + e, + ) + for var, inheritable, handler_class, env_var in _configure_hooks: + create_one = ( + env_var is not None + and env_var_is_set(env_var) + and handler_class is not None + ) + if var.get() is not None or create_one: + var_handler = var.get() or cast(Type[BaseCallbackHandler], handler_class)() + if handler_class is None: + if not any( + handler is var_handler # direct pointer comparison + for handler in callback_manager.handlers + ): + callback_manager.add_handler(var_handler, inheritable) + else: + if not any( + isinstance(handler, handler_class) + for handler in callback_manager.handlers + ): + callback_manager.add_handler(var_handler, inheritable) + return callback_manager diff --git a/libs/core/langchain_core/callbacks/stdout.py b/libs/core/langchain_core/callbacks/stdout.py new file mode 100644 index 00000000000..85b61ec40ee --- /dev/null +++ b/libs/core/langchain_core/callbacks/stdout.py @@ -0,0 +1,97 @@ +"""Callback Handler that prints to std out.""" +from typing import Any, Dict, List, Optional + +from langchain_core.callbacks.base import BaseCallbackHandler +from langchain_core.schema import AgentAction, AgentFinish, LLMResult +from langchain_core.utils.input import print_text + + +class StdOutCallbackHandler(BaseCallbackHandler): + """Callback Handler that prints to std out.""" + + def __init__(self, color: Optional[str] = None) -> None: + """Initialize callback handler.""" + self.color = color + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Print out the prompts.""" + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Print out that we are entering a chain.""" + class_name = serialized.get("name", serialized.get("id", [""])[-1]) + print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Print out that we finished a chain.""" + print("\n\033[1m> Finished chain.\033[0m") + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_tool_start( + self, + serialized: Dict[str, Any], + 
input_str: str, + **kwargs: Any, + ) -> None: + """Do nothing.""" + pass + + def on_agent_action( + self, action: AgentAction, color: Optional[str] = None, **kwargs: Any + ) -> Any: + """Run on agent action.""" + print_text(action.log, color=color or self.color) + + def on_tool_end( + self, + output: str, + color: Optional[str] = None, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + """If not the final action, print out observation.""" + if observation_prefix is not None: + print_text(f"\n{observation_prefix}") + print_text(output, color=color or self.color) + if llm_prefix is not None: + print_text(f"\n{llm_prefix}") + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_text( + self, + text: str, + color: Optional[str] = None, + end: str = "", + **kwargs: Any, + ) -> None: + """Run when agent ends.""" + print_text(text, color=color or self.color, end=end) + + def on_agent_finish( + self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any + ) -> None: + """Run on agent end.""" + print_text(finish.log, color=color or self.color, end="\n") diff --git a/libs/core/langchain_core/callbacks/streaming_stdout.py b/libs/core/langchain_core/callbacks/streaming_stdout.py new file mode 100644 index 00000000000..a678e836206 --- /dev/null +++ b/libs/core/langchain_core/callbacks/streaming_stdout.py @@ -0,0 +1,67 @@ +"""Callback Handler streams to stdout on new llm token.""" +import sys +from typing import Any, Dict, List + +from langchain_core.callbacks.base import BaseCallbackHandler +from langchain_core.schema import AgentAction, AgentFinish, LLMResult +from langchain_core.schema.messages import BaseMessage + + +class StreamingStdOutCallbackHandler(BaseCallbackHandler): + """Callback handler for streaming. Only works with LLMs that support streaming.""" + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts running.""" + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + **kwargs: Any, + ) -> None: + """Run when LLM starts running.""" + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run on new LLM token. 
Only available when streaming is enabled.""" + sys.stdout.write(token) + sys.stdout.flush() + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when LLM errors.""" + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when chain errors.""" + + def on_tool_start( + self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run on agent action.""" + pass + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Run when tool errors.""" + + def on_text(self, text: str, **kwargs: Any) -> None: + """Run on arbitrary text.""" + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end.""" diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__init__.py b/libs/core/langchain_core/callbacks/tracers/__init__.py similarity index 100% rename from libs/langchain/tests/unit_tests/schema/runnable/__init__.py rename to libs/core/langchain_core/callbacks/tracers/__init__.py diff --git a/libs/core/langchain_core/callbacks/tracers/base.py b/libs/core/langchain_core/callbacks/tracers/base.py new file mode 100644 index 00000000000..1b3ba409273 --- /dev/null +++ b/libs/core/langchain_core/callbacks/tracers/base.py @@ -0,0 +1,537 @@ +"""Base interfaces for tracing runs.""" +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, List, Optional, Sequence, Union, cast +from uuid import UUID + +from tenacity import RetryCallState + +from langchain_core.callbacks.base import BaseCallbackHandler +from langchain_core.callbacks.tracers.schemas import Run +from langchain_core.load.dump import dumpd +from langchain_core.schema.document import Document +from langchain_core.schema.output import ( + ChatGeneration, + ChatGenerationChunk, + GenerationChunk, + LLMResult, +) + +logger = logging.getLogger(__name__) + + +class TracerException(Exception): + """Base class for exceptions in tracers module.""" + + +class BaseTracer(BaseCallbackHandler, ABC): + """Base interface for tracers.""" + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.run_map: Dict[str, Run] = {} + + @staticmethod + def _add_child_run( + parent_run: Run, + child_run: Run, + ) -> None: + """Add child run to a chain run or tool run.""" + parent_run.child_runs.append(child_run) + + @abstractmethod + def _persist_run(self, run: Run) -> None: + """Persist a run.""" + + def _start_trace(self, run: Run) -> None: + """Start a trace for a run.""" + if run.parent_run_id: + parent_run = self.run_map.get(str(run.parent_run_id)) + if parent_run: + self._add_child_run(parent_run, run) + parent_run.child_execution_order = max( + parent_run.child_execution_order, run.child_execution_order + ) + else: + logger.debug(f"Parent run with UUID {run.parent_run_id} not found.") + self.run_map[str(run.id)] = run + self._on_run_create(run) + 
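+ # run_map holds in-flight runs keyed by str(run.id); _start_trace above adds
+ # an entry and _end_trace below removes it once the run finishes or errors.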
+ def _end_trace(self, run: Run) -> None: + """End a trace for a run.""" + if not run.parent_run_id: + self._persist_run(run) + else: + parent_run = self.run_map.get(str(run.parent_run_id)) + if parent_run is None: + logger.debug(f"Parent run with UUID {run.parent_run_id} not found.") + elif ( + run.child_execution_order is not None + and parent_run.child_execution_order is not None + and run.child_execution_order > parent_run.child_execution_order + ): + parent_run.child_execution_order = run.child_execution_order + self.run_map.pop(str(run.id)) + self._on_run_update(run) + + def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int: + """Get the execution order for a run.""" + if parent_run_id is None: + return 1 + + parent_run = self.run_map.get(parent_run_id) + if parent_run is None: + logger.debug(f"Parent run with UUID {parent_run_id} not found.") + return 1 + if parent_run.child_execution_order is None: + raise TracerException( + f"Parent run with UUID {parent_run_id} has no child execution order." + ) + + return parent_run.child_execution_order + 1 + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> Run: + """Start a trace for an LLM run.""" + parent_run_id_ = str(parent_run_id) if parent_run_id else None + execution_order = self._get_execution_order(parent_run_id_) + start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) + llm_run = Run( + id=run_id, + parent_run_id=parent_run_id, + serialized=serialized, + inputs={"prompts": prompts}, + extra=kwargs, + events=[{"name": "start", "time": start_time}], + start_time=start_time, + execution_order=execution_order, + child_execution_order=execution_order, + run_type="llm", + tags=tags or [], + name=name, + ) + self._start_trace(llm_run) + self._on_llm_start(llm_run) + return llm_run + + def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Run: + """Run on new LLM token. 
Only available when streaming is enabled.""" + if not run_id: + raise TracerException("No run_id provided for on_llm_new_token callback.") + + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) + if llm_run is None or llm_run.run_type != "llm": + raise TracerException(f"No LLM Run found to be traced for {run_id}") + event_kwargs: Dict[str, Any] = {"token": token} + if chunk: + event_kwargs["chunk"] = chunk + llm_run.events.append( + { + "name": "new_token", + "time": datetime.utcnow(), + "kwargs": event_kwargs, + }, + ) + self._on_llm_new_token(llm_run, token, chunk) + return llm_run + + def on_retry( + self, + retry_state: RetryCallState, + *, + run_id: UUID, + **kwargs: Any, + ) -> Run: + if not run_id: + raise TracerException("No run_id provided for on_retry callback.") + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) + if llm_run is None: + raise TracerException("No Run found to be traced for on_retry") + retry_d: Dict[str, Any] = { + "slept": retry_state.idle_for, + "attempt": retry_state.attempt_number, + } + if retry_state.outcome is None: + retry_d["outcome"] = "N/A" + elif retry_state.outcome.failed: + retry_d["outcome"] = "failed" + exception = retry_state.outcome.exception() + retry_d["exception"] = str(exception) + retry_d["exception_type"] = exception.__class__.__name__ + else: + retry_d["outcome"] = "success" + retry_d["result"] = str(retry_state.outcome.result()) + llm_run.events.append( + { + "name": "retry", + "time": datetime.utcnow(), + "kwargs": retry_d, + }, + ) + return llm_run + + def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run: + """End a trace for an LLM run.""" + if not run_id: + raise TracerException("No run_id provided for on_llm_end callback.") + + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) + if llm_run is None or llm_run.run_type != "llm": + raise TracerException(f"No LLM Run found to be traced for {run_id}") + llm_run.outputs = response.dict() + for i, generations in enumerate(response.generations): + for j, generation in enumerate(generations): + output_generation = llm_run.outputs["generations"][i][j] + if "message" in output_generation: + output_generation["message"] = dumpd( + cast(ChatGeneration, generation).message + ) + llm_run.end_time = datetime.utcnow() + llm_run.events.append({"name": "end", "time": llm_run.end_time}) + self._end_trace(llm_run) + self._on_llm_end(llm_run) + return llm_run + + def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + **kwargs: Any, + ) -> Run: + """Handle an error for an LLM run.""" + if not run_id: + raise TracerException("No run_id provided for on_llm_error callback.") + + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) + if llm_run is None or llm_run.run_type != "llm": + raise TracerException(f"No LLM Run found to be traced for {run_id}") + llm_run.error = repr(error) + llm_run.end_time = datetime.utcnow() + llm_run.events.append({"name": "error", "time": llm_run.end_time}) + self._end_trace(llm_run) + self._on_llm_error(llm_run) + return llm_run + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + run_type: Optional[str] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> Run: + """Start a trace for a chain run.""" + parent_run_id_ = str(parent_run_id) if parent_run_id else None + execution_order = 
self._get_execution_order(parent_run_id_) + start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) + chain_run = Run( + id=run_id, + parent_run_id=parent_run_id, + serialized=serialized, + inputs=inputs if isinstance(inputs, dict) else {"input": inputs}, + extra=kwargs, + events=[{"name": "start", "time": start_time}], + start_time=start_time, + execution_order=execution_order, + child_execution_order=execution_order, + child_runs=[], + run_type=run_type or "chain", + name=name, + tags=tags or [], + ) + self._start_trace(chain_run) + self._on_chain_start(chain_run) + return chain_run + + def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + inputs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Run: + """End a trace for a chain run.""" + if not run_id: + raise TracerException("No run_id provided for on_chain_end callback.") + chain_run = self.run_map.get(str(run_id)) + if chain_run is None: + raise TracerException(f"No chain Run found to be traced for {run_id}") + + chain_run.outputs = ( + outputs if isinstance(outputs, dict) else {"output": outputs} + ) + chain_run.end_time = datetime.utcnow() + chain_run.events.append({"name": "end", "time": chain_run.end_time}) + if inputs is not None: + chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs} + self._end_trace(chain_run) + self._on_chain_end(chain_run) + return chain_run + + def on_chain_error( + self, + error: BaseException, + *, + inputs: Optional[Dict[str, Any]] = None, + run_id: UUID, + **kwargs: Any, + ) -> Run: + """Handle an error for a chain run.""" + if not run_id: + raise TracerException("No run_id provided for on_chain_error callback.") + chain_run = self.run_map.get(str(run_id)) + if chain_run is None: + raise TracerException(f"No chain Run found to be traced for {run_id}") + + chain_run.error = repr(error) + chain_run.end_time = datetime.utcnow() + chain_run.events.append({"name": "error", "time": chain_run.end_time}) + if inputs is not None: + chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs} + self._end_trace(chain_run) + self._on_chain_error(chain_run) + return chain_run + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> Run: + """Start a trace for a tool run.""" + parent_run_id_ = str(parent_run_id) if parent_run_id else None + execution_order = self._get_execution_order(parent_run_id_) + start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) + tool_run = Run( + id=run_id, + parent_run_id=parent_run_id, + serialized=serialized, + inputs={"input": input_str}, + extra=kwargs, + events=[{"name": "start", "time": start_time}], + start_time=start_time, + execution_order=execution_order, + child_execution_order=execution_order, + child_runs=[], + run_type="tool", + tags=tags or [], + name=name, + ) + self._start_trace(tool_run) + self._on_tool_start(tool_run) + return tool_run + + def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run: + """End a trace for a tool run.""" + if not run_id: + raise TracerException("No run_id provided for on_tool_end callback.") + tool_run = self.run_map.get(str(run_id)) + if tool_run is None or tool_run.run_type != "tool": + raise TracerException(f"No tool Run found to be traced for {run_id}") + + tool_run.outputs = 
{"output": output} + tool_run.end_time = datetime.utcnow() + tool_run.events.append({"name": "end", "time": tool_run.end_time}) + self._end_trace(tool_run) + self._on_tool_end(tool_run) + return tool_run + + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + **kwargs: Any, + ) -> Run: + """Handle an error for a tool run.""" + if not run_id: + raise TracerException("No run_id provided for on_tool_error callback.") + tool_run = self.run_map.get(str(run_id)) + if tool_run is None or tool_run.run_type != "tool": + raise TracerException(f"No tool Run found to be traced for {run_id}") + + tool_run.error = repr(error) + tool_run.end_time = datetime.utcnow() + tool_run.events.append({"name": "error", "time": tool_run.end_time}) + self._end_trace(tool_run) + self._on_tool_error(tool_run) + return tool_run + + def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> Run: + """Run when Retriever starts running.""" + parent_run_id_ = str(parent_run_id) if parent_run_id else None + execution_order = self._get_execution_order(parent_run_id_) + start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) + retrieval_run = Run( + id=run_id, + name=name or "Retriever", + parent_run_id=parent_run_id, + serialized=serialized, + inputs={"query": query}, + extra=kwargs, + events=[{"name": "start", "time": start_time}], + start_time=start_time, + execution_order=execution_order, + child_execution_order=execution_order, + tags=tags, + child_runs=[], + run_type="retriever", + ) + self._start_trace(retrieval_run) + self._on_retriever_start(retrieval_run) + return retrieval_run + + def on_retriever_error( + self, + error: BaseException, + *, + run_id: UUID, + **kwargs: Any, + ) -> Run: + """Run when Retriever errors.""" + if not run_id: + raise TracerException("No run_id provided for on_retriever_error callback.") + retrieval_run = self.run_map.get(str(run_id)) + if retrieval_run is None or retrieval_run.run_type != "retriever": + raise TracerException(f"No retriever Run found to be traced for {run_id}") + + retrieval_run.error = repr(error) + retrieval_run.end_time = datetime.utcnow() + retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time}) + self._end_trace(retrieval_run) + self._on_retriever_error(retrieval_run) + return retrieval_run + + def on_retriever_end( + self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any + ) -> Run: + """Run when Retriever ends running.""" + if not run_id: + raise TracerException("No run_id provided for on_retriever_end callback.") + retrieval_run = self.run_map.get(str(run_id)) + if retrieval_run is None or retrieval_run.run_type != "retriever": + raise TracerException(f"No retriever Run found to be traced for {run_id}") + retrieval_run.outputs = {"documents": documents} + retrieval_run.end_time = datetime.utcnow() + retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time}) + self._end_trace(retrieval_run) + self._on_retriever_end(retrieval_run) + return retrieval_run + + def __deepcopy__(self, memo: dict) -> BaseTracer: + """Deepcopy the tracer.""" + return self + + def __copy__(self) -> BaseTracer: + """Copy the tracer.""" + return self + + def _on_run_create(self, run: Run) -> None: + """Process a run upon creation.""" + + def _on_run_update(self, run: Run) -> None: 
+ """Process a run upon update.""" + + def _on_llm_start(self, run: Run) -> None: + """Process the LLM Run upon start.""" + + def _on_llm_new_token( + self, + run: Run, + token: str, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], + ) -> None: + """Process new LLM token.""" + + def _on_llm_end(self, run: Run) -> None: + """Process the LLM Run.""" + + def _on_llm_error(self, run: Run) -> None: + """Process the LLM Run upon error.""" + + def _on_chain_start(self, run: Run) -> None: + """Process the Chain Run upon start.""" + + def _on_chain_end(self, run: Run) -> None: + """Process the Chain Run.""" + + def _on_chain_error(self, run: Run) -> None: + """Process the Chain Run upon error.""" + + def _on_tool_start(self, run: Run) -> None: + """Process the Tool Run upon start.""" + + def _on_tool_end(self, run: Run) -> None: + """Process the Tool Run.""" + + def _on_tool_error(self, run: Run) -> None: + """Process the Tool Run upon error.""" + + def _on_chat_model_start(self, run: Run) -> None: + """Process the Chat Model Run upon start.""" + + def _on_retriever_start(self, run: Run) -> None: + """Process the Retriever Run upon start.""" + + def _on_retriever_end(self, run: Run) -> None: + """Process the Retriever Run.""" + + def _on_retriever_error(self, run: Run) -> None: + """Process the Retriever Run upon error.""" diff --git a/libs/core/langchain_core/callbacks/tracers/evaluation.py b/libs/core/langchain_core/callbacks/tracers/evaluation.py new file mode 100644 index 00000000000..fa0f62e8879 --- /dev/null +++ b/libs/core/langchain_core/callbacks/tracers/evaluation.py @@ -0,0 +1,223 @@ +"""A tracer that runs evaluators over completed runs.""" +from __future__ import annotations + +import logging +import threading +import weakref +from concurrent.futures import Future, ThreadPoolExecutor, wait +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast +from uuid import UUID + +import langsmith +from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults + +from langchain_core.callbacks import manager +from langchain_core.callbacks.tracers import langchain as langchain_tracer +from langchain_core.callbacks.tracers.base import BaseTracer +from langchain_core.callbacks.tracers.langchain import _get_executor +from langchain_core.callbacks.tracers.schemas import Run + +logger = logging.getLogger(__name__) + +_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet() + + +def wait_for_all_evaluators() -> None: + """Wait for all tracers to finish.""" + global _TRACERS + for tracer in list(_TRACERS): + if tracer is not None: + tracer.wait_for_futures() + + +class EvaluatorCallbackHandler(BaseTracer): + """A tracer that runs a run evaluator whenever a run is persisted. + + Parameters + ---------- + evaluators : Sequence[RunEvaluator] + The run evaluators to apply to all top level runs. + client : LangSmith Client, optional + The LangSmith client instance to use for evaluating the runs. + If not specified, a new instance will be created. + example_id : Union[UUID, str], optional + The example ID to be associated with the runs. + project_name : str, optional + The LangSmith project name to be organize eval chain runs under. + + Attributes + ---------- + example_id : Union[UUID, None] + The example ID associated with the runs. + client : Client + The LangSmith client instance used for evaluating the runs. + evaluators : Sequence[RunEvaluator] + The sequence of run evaluators to be executed. 
+ executor : ThreadPoolExecutor + The thread pool executor used for running the evaluators. + futures : Set[Future] + The set of futures representing the running evaluators. + skip_unfinished : bool + Whether to skip runs that are not finished or raised + an error. + project_name : Optional[str] + The LangSmith project name to be organize eval chain runs under. + """ + + name = "evaluator_callback_handler" + + def __init__( + self, + evaluators: Sequence[langsmith.RunEvaluator], + client: Optional[langsmith.Client] = None, + example_id: Optional[Union[UUID, str]] = None, + skip_unfinished: bool = True, + project_name: Optional[str] = "evaluators", + max_concurrency: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.example_id = ( + UUID(example_id) if isinstance(example_id, str) else example_id + ) + self.client = client or langchain_tracer.get_client() + self.evaluators = evaluators + if max_concurrency is None: + self.executor: Optional[ThreadPoolExecutor] = _get_executor() + elif max_concurrency > 0: + self.executor = ThreadPoolExecutor(max_workers=max_concurrency) + weakref.finalize( + self, + lambda: cast(ThreadPoolExecutor, self.executor).shutdown(wait=True), + ) + else: + self.executor = None + self.futures: weakref.WeakSet[Future] = weakref.WeakSet() + self.skip_unfinished = skip_unfinished + self.project_name = project_name + self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {} + self.lock = threading.Lock() + global _TRACERS + _TRACERS.add(self) + + def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None: + """Evaluate the run in the project. + + Parameters + ---------- + run : Run + The run to be evaluated. + evaluator : RunEvaluator + The evaluator to use for evaluating the run. + + """ + try: + if self.project_name is None: + eval_result = self.client.evaluate_run(run, evaluator) + eval_results = [eval_result] + with manager.tracing_v2_enabled( + project_name=self.project_name, tags=["eval"], client=self.client + ) as cb: + reference_example = ( + self.client.read_example(run.reference_example_id) + if run.reference_example_id + else None + ) + evaluation_result = evaluator.evaluate_run( + # This is subclass, but getting errors for some reason + run, # type: ignore + example=reference_example, + ) + eval_results = self._log_evaluation_feedback( + evaluation_result, + run, + source_run_id=cb.latest_run.id if cb.latest_run else None, + ) + except Exception as e: + logger.error( + f"Error evaluating run {run.id} with " + f"{evaluator.__class__.__name__}: {repr(e)}", + exc_info=True, + ) + raise e + example_id = str(run.reference_example_id) + with self.lock: + for res in eval_results: + run_id = ( + str(getattr(res, "target_run_id")) + if hasattr(res, "target_run_id") + else str(run.id) + ) + self.logged_eval_results.setdefault((run_id, example_id), []).append( + res + ) + + def _select_eval_results( + self, + results: Union[EvaluationResult, EvaluationResults], + ) -> List[EvaluationResult]: + if isinstance(results, EvaluationResult): + results_ = [results] + elif isinstance(results, dict) and "results" in results: + results_ = cast(List[EvaluationResult], results["results"]) + else: + raise TypeError( + f"Invalid evaluation result type {type(results)}." + " Expected EvaluationResult or EvaluationResults." 
+ ) + return results_ + + def _log_evaluation_feedback( + self, + evaluator_response: Union[EvaluationResult, EvaluationResults], + run: Run, + source_run_id: Optional[UUID] = None, + ) -> List[EvaluationResult]: + results = self._select_eval_results(evaluator_response) + for res in results: + source_info_: Dict[str, Any] = {} + if res.evaluator_info: + source_info_ = {**res.evaluator_info, **source_info_} + run_id_ = ( + getattr(res, "target_run_id") + if hasattr(res, "target_run_id") and res.target_run_id is not None + else run.id + ) + self.client.create_feedback( + run_id_, + res.key, + score=res.score, + value=res.value, + comment=res.comment, + correction=res.correction, + source_info=source_info_, + source_run_id=res.source_run_id or source_run_id, + feedback_source_type=langsmith.schemas.FeedbackSourceType.MODEL, + ) + return results + + def _persist_run(self, run: Run) -> None: + """Run the evaluator on the run. + + Parameters + ---------- + run : Run + The run to be evaluated. + + """ + if self.skip_unfinished and not run.outputs: + logger.debug(f"Skipping unfinished run {run.id}") + return + run_ = run.copy() + run_.reference_example_id = self.example_id + for evaluator in self.evaluators: + if self.executor is None: + self._evaluate_in_project(run_, evaluator) + else: + self.futures.add( + self.executor.submit(self._evaluate_in_project, run_, evaluator) + ) + + def wait_for_futures(self) -> None: + """Wait for all futures to complete.""" + wait(self.futures) diff --git a/libs/core/langchain_core/callbacks/tracers/langchain.py b/libs/core/langchain_core/callbacks/tracers/langchain.py new file mode 100644 index 00000000000..7ab7f44a1f6 --- /dev/null +++ b/libs/core/langchain_core/callbacks/tracers/langchain.py @@ -0,0 +1,262 @@ +"""A Tracer implementation that records to LangChain endpoint.""" +from __future__ import annotations + +import logging +import weakref +from concurrent.futures import Future, ThreadPoolExecutor, wait +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Union +from uuid import UUID + +from langsmith import Client +from langsmith import utils as ls_utils +from tenacity import ( + Retrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) + +from langchain_core.callbacks.tracers.base import BaseTracer +from langchain_core.callbacks.tracers.schemas import Run +from langchain_core.env import get_runtime_environment +from langchain_core.load.dump import dumpd +from langchain_core.schema.messages import BaseMessage + +logger = logging.getLogger(__name__) +_LOGGED = set() +_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet() +_CLIENT: Optional[Client] = None +_EXECUTOR: Optional[ThreadPoolExecutor] = None + + +def log_error_once(method: str, exception: Exception) -> None: + """Log an error once.""" + global _LOGGED + if (method, type(exception)) in _LOGGED: + return + _LOGGED.add((method, type(exception))) + logger.error(exception) + + +def wait_for_all_tracers() -> None: + """Wait for all tracers to finish.""" + global _TRACERS + for tracer in list(_TRACERS): + if tracer is not None: + tracer.wait_for_futures() + + +def get_client() -> Client: + """Get the client.""" + global _CLIENT + if _CLIENT is None: + _CLIENT = Client() + return _CLIENT + + +def _get_executor() -> ThreadPoolExecutor: + """Get the executor.""" + global _EXECUTOR + if _EXECUTOR is None: + _EXECUTOR = ThreadPoolExecutor() + return _EXECUTOR + + +def _copy(run: Run) -> Run: + """Copy a run.""" + try: + return 
run.copy(deep=True) + except TypeError: + # Fallback in case the object contains a lock or other + # non-pickleable object + return run.copy() + + +class LangChainTracer(BaseTracer): + """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" + + def __init__( + self, + example_id: Optional[Union[UUID, str]] = None, + project_name: Optional[str] = None, + client: Optional[Client] = None, + tags: Optional[List[str]] = None, + use_threading: bool = True, + **kwargs: Any, + ) -> None: + """Initialize the LangChain tracer.""" + super().__init__(**kwargs) + self.example_id = ( + UUID(example_id) if isinstance(example_id, str) else example_id + ) + self.project_name = project_name or ls_utils.get_tracer_project() + self.client = client or get_client() + self._futures: weakref.WeakSet[Future] = weakref.WeakSet() + self.tags = tags or [] + self.executor = _get_executor() if use_threading else None + self.latest_run: Optional[Run] = None + global _TRACERS + _TRACERS.add(self) + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Start a trace for an LLM run.""" + parent_run_id_ = str(parent_run_id) if parent_run_id else None + execution_order = self._get_execution_order(parent_run_id_) + start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) + chat_model_run = Run( + id=run_id, + parent_run_id=parent_run_id, + serialized=serialized, + inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]}, + extra=kwargs, + events=[{"name": "start", "time": start_time}], + start_time=start_time, + execution_order=execution_order, + child_execution_order=execution_order, + run_type="llm", + tags=tags, + name=name, + ) + self._start_trace(chat_model_run) + self._on_chat_model_start(chat_model_run) + + def _persist_run(self, run: Run) -> None: + run_ = run.copy() + run_.reference_example_id = self.example_id + self.latest_run = run_ + + def get_run_url(self) -> str: + """Get the LangSmith root run URL""" + if not self.latest_run: + raise ValueError("No traced run found.") + # If this is the first run in a project, the project may not yet be created. + # This method is only really useful for debugging flows, so we will assume + # there is some tolerace for latency. 
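# A minimal, self-contained sketch of the retry pattern used in get_run_url()
# above: tenacity's Retrying iterator with jittered exponential backoff. The
# names TransientError and flaky_lookup are hypothetical stand-ins for
# ls_utils.LangSmithError and client.get_run_url; they are not part of this patch.
from tenacity import (
    Retrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)


class TransientError(Exception):
    """Raised while the looked-up resource does not exist yet."""


_calls = {"count": 0}


def flaky_lookup() -> str:
    # Fails twice, then succeeds, mimicking a project that is created with some delay.
    _calls["count"] += 1
    if _calls["count"] < 3:
        raise TransientError("not ready yet")
    return "https://example.invalid/run/123"


def lookup_with_retry() -> str:
    # Each iteration yields an attempt context manager; a successful return from
    # inside `with attempt:` ends the retry loop, while TransientError triggers
    # another attempt (up to 5 here, with jittered exponential backoff between them).
    for attempt in Retrying(
        stop=stop_after_attempt(5),
        wait=wait_exponential_jitter(),
        retry=retry_if_exception_type(TransientError),
    ):
        with attempt:
            return flaky_lookup()
    # Unreachable in practice: tenacity raises RetryError once attempts run out.
    raise RuntimeError("retries exhausted")


if __name__ == "__main__":
    print(lookup_with_retry())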
+ for attempt in Retrying( + stop=stop_after_attempt(5), + wait=wait_exponential_jitter(), + retry=retry_if_exception_type(ls_utils.LangSmithError), + ): + with attempt: + return self.client.get_run_url( + run=self.latest_run, project_name=self.project_name + ) + raise ValueError("Failed to get run URL.") + + def _get_tags(self, run: Run) -> List[str]: + """Get combined tags for a run.""" + tags = set(run.tags or []) + tags.update(self.tags or []) + return list(tags) + + def _persist_run_single(self, run: Run) -> None: + """Persist a run.""" + run_dict = run.dict(exclude={"child_runs"}) + run_dict["tags"] = self._get_tags(run) + extra = run_dict.get("extra", {}) + extra["runtime"] = get_runtime_environment() + run_dict["extra"] = extra + try: + self.client.create_run(**run_dict, project_name=self.project_name) + except Exception as e: + # Errors are swallowed by the thread executor so we need to log them here + log_error_once("post", e) + raise + + def _update_run_single(self, run: Run) -> None: + """Update a run.""" + try: + run_dict = run.dict() + run_dict["tags"] = self._get_tags(run) + self.client.update_run(run.id, **run_dict) + except Exception as e: + # Errors are swallowed by the thread executor so we need to log them here + log_error_once("patch", e) + raise + + def _submit(self, function: Callable[[Run], None], run: Run) -> None: + """Submit a function to the executor.""" + if self.executor is None: + function(run) + else: + self._futures.add(self.executor.submit(function, run)) + + def _on_llm_start(self, run: Run) -> None: + """Persist an LLM run.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id + self._submit(self._persist_run_single, _copy(run)) + + def _on_chat_model_start(self, run: Run) -> None: + """Persist an LLM run.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id + self._submit(self._persist_run_single, _copy(run)) + + def _on_llm_end(self, run: Run) -> None: + """Process the LLM Run.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_llm_error(self, run: Run) -> None: + """Process the LLM Run upon error.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_chain_start(self, run: Run) -> None: + """Process the Chain Run upon start.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id + self._submit(self._persist_run_single, _copy(run)) + + def _on_chain_end(self, run: Run) -> None: + """Process the Chain Run.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_chain_error(self, run: Run) -> None: + """Process the Chain Run upon error.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_tool_start(self, run: Run) -> None: + """Process the Tool Run upon start.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id + self._submit(self._persist_run_single, _copy(run)) + + def _on_tool_end(self, run: Run) -> None: + """Process the Tool Run.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_tool_error(self, run: Run) -> None: + """Process the Tool Run upon error.""" + self._submit(self._update_run_single, _copy(run)) + + def _on_retriever_start(self, run: Run) -> None: + """Process the Retriever Run upon start.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id + self._submit(self._persist_run_single, _copy(run)) + + def _on_retriever_end(self, run: Run) -> None: + """Process the Retriever Run.""" + self._submit(self._update_run_single, _copy(run)) + + def 
_on_retriever_error(self, run: Run) -> None: + """Process the Retriever Run upon error.""" + self._submit(self._update_run_single, _copy(run)) + + def wait_for_futures(self) -> None: + """Wait for the given futures to complete.""" + wait(self._futures) diff --git a/libs/core/langchain_core/callbacks/tracers/langchain_v1.py b/libs/core/langchain_core/callbacks/tracers/langchain_v1.py new file mode 100644 index 00000000000..733a8dcf250 --- /dev/null +++ b/libs/core/langchain_core/callbacks/tracers/langchain_v1.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import logging +import os +from typing import Any, Dict, Optional, Union + +import requests + +from langchain_core.callbacks.tracers.base import BaseTracer +from langchain_core.callbacks.tracers.schemas import ( + ChainRun, + LLMRun, + Run, + ToolRun, + TracerSession, + TracerSessionV1, + TracerSessionV1Base, +) +from langchain_core.schema.messages import get_buffer_string +from langchain_core.utils import raise_for_status_with_text + +logger = logging.getLogger(__name__) + + +def get_headers() -> Dict[str, Any]: + """Get the headers for the LangChain API.""" + headers: Dict[str, Any] = {"Content-Type": "application/json"} + if os.getenv("LANGCHAIN_API_KEY"): + headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") + return headers + + +def _get_endpoint() -> str: + return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") + + +class LangChainTracerV1(BaseTracer): + """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize the LangChain tracer.""" + super().__init__(**kwargs) + self.session: Optional[TracerSessionV1] = None + self._endpoint = _get_endpoint() + self._headers = get_headers() + + def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]: + session = self.session or self.load_default_session() + if not isinstance(session, TracerSessionV1): + raise ValueError( + "LangChainTracerV1 is not compatible with" + f" session of type {type(session)}" + ) + + if run.run_type == "llm": + if "prompts" in run.inputs: + prompts = run.inputs["prompts"] + elif "messages" in run.inputs: + prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]] + else: + raise ValueError("No prompts found in LLM run inputs") + return LLMRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + extra=run.extra, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + error=run.error, + prompts=prompts, + response=run.outputs if run.outputs else None, + ) + if run.run_type == "chain": + child_runs = [self._convert_to_v1_run(run) for run in run.child_runs] + return ChainRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + inputs=run.inputs, + outputs=run.outputs, + error=run.error, + extra=run.extra, + child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)], + child_chain_runs=[ + run for run in child_runs if isinstance(run, ChainRun) + ], + child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], + ) + if run.run_type == "tool": + child_runs = 
[self._convert_to_v1_run(run) for run in run.child_runs] + return ToolRun( + uuid=str(run.id) if run.id else None, + parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, + start_time=run.start_time, + end_time=run.end_time, + execution_order=run.execution_order, + child_execution_order=run.child_execution_order, + serialized=run.serialized, + session_id=session.id, + action=str(run.serialized), + tool_input=run.inputs.get("input", ""), + output=None if run.outputs is None else run.outputs.get("output"), + error=run.error, + extra=run.extra, + child_chain_runs=[ + run for run in child_runs if isinstance(run, ChainRun) + ], + child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], + child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)], + ) + raise ValueError(f"Unknown run type: {run.run_type}") + + def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None: + """Persist a run.""" + if isinstance(run, Run): + v1_run = self._convert_to_v1_run(run) + else: + v1_run = run + if isinstance(v1_run, LLMRun): + endpoint = f"{self._endpoint}/llm-runs" + elif isinstance(v1_run, ChainRun): + endpoint = f"{self._endpoint}/chain-runs" + else: + endpoint = f"{self._endpoint}/tool-runs" + + try: + response = requests.post( + endpoint, + data=v1_run.json(), + headers=self._headers, + ) + raise_for_status_with_text(response) + except Exception as e: + logger.warning(f"Failed to persist run: {e}") + + def _persist_session( + self, session_create: TracerSessionV1Base + ) -> Union[TracerSessionV1, TracerSession]: + """Persist a session.""" + try: + r = requests.post( + f"{self._endpoint}/sessions", + data=session_create.json(), + headers=self._headers, + ) + session = TracerSessionV1(id=r.json()["id"], **session_create.dict()) + except Exception as e: + logger.warning(f"Failed to create session, using default session: {e}") + session = TracerSessionV1(id=1, **session_create.dict()) + return session + + def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1: + """Load a session from the tracer.""" + try: + url = f"{self._endpoint}/sessions" + if session_name: + url += f"?name={session_name}" + r = requests.get(url, headers=self._headers) + + tracer_session = TracerSessionV1(**r.json()[0]) + except Exception as e: + session_type = "default" if not session_name else session_name + logger.warning( + f"Failed to load {session_type} session, using empty session: {e}" + ) + tracer_session = TracerSessionV1(id=1) + + self.session = tracer_session + return tracer_session + + def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]: + """Load a session with the given name from the tracer.""" + return self._load_session(session_name) + + def load_default_session(self) -> Union[TracerSessionV1, TracerSession]: + """Load the default tracing session and set it as the Tracer's session.""" + return self._load_session("default") diff --git a/libs/core/langchain_core/callbacks/tracers/log_stream.py b/libs/core/langchain_core/callbacks/tracers/log_stream.py new file mode 100644 index 00000000000..53c28e4d06c --- /dev/null +++ b/libs/core/langchain_core/callbacks/tracers/log_stream.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import math +import threading +from collections import defaultdict +from typing import ( + Any, + AsyncIterator, + Dict, + List, + Optional, + Sequence, + TypedDict, + Union, +) +from uuid import UUID + +import jsonpatch +from anyio import create_memory_object_stream + +from 
langchain_core.callbacks.tracers.base import BaseTracer +from langchain_core.callbacks.tracers.schemas import Run +from langchain_core.load.load import load +from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk + + +class LogEntry(TypedDict): + """A single entry in the run log.""" + + id: str + """ID of the sub-run.""" + name: str + """Name of the object being run.""" + type: str + """Type of the object being run, eg. prompt, chain, llm, etc.""" + tags: List[str] + """List of tags for the run.""" + metadata: Dict[str, Any] + """Key-value pairs of metadata for the run.""" + start_time: str + """ISO-8601 timestamp of when the run started.""" + + streamed_output_str: List[str] + """List of LLM tokens streamed by this run, if applicable.""" + final_output: Optional[Any] + """Final output of this run. + Only available after the run has finished successfully.""" + end_time: Optional[str] + """ISO-8601 timestamp of when the run ended. + Only available after the run has finished.""" + + +class RunState(TypedDict): + """State of the run.""" + + id: str + """ID of the run.""" + streamed_output: List[Any] + """List of output chunks streamed by Runnable.stream()""" + final_output: Optional[Any] + """Final output of the run, usually the result of aggregating (`+`) streamed_output. + Only available after the run has finished successfully.""" + + logs: Dict[str, LogEntry] + """Map of run names to sub-runs. If filters were supplied, this list will + contain only the runs that matched the filters.""" + + +class RunLogPatch: + """A patch to the run log.""" + + ops: List[Dict[str, Any]] + """List of jsonpatch operations, which describe how to create the run state + from an empty dict. This is the minimal representation of the log, designed to + be serialized as JSON and sent over the wire to reconstruct the log on the other + side. 
Reconstruction of the state can be done with any jsonpatch-compliant library, + see https://jsonpatch.com for more information.""" + + def __init__(self, *ops: Dict[str, Any]) -> None: + self.ops = list(ops) + + def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: + if type(other) == RunLogPatch: + ops = self.ops + other.ops + state = jsonpatch.apply_patch(None, ops) + return RunLog(*ops, state=state) + + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) + + def __repr__(self) -> str: + from pprint import pformat + + # 1:-1 to get rid of the [] around the list + return f"RunLogPatch({pformat(self.ops)[1:-1]})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, RunLogPatch) and self.ops == other.ops + + +class RunLog(RunLogPatch): + """A run log.""" + + state: RunState + """Current state of the log, obtained from applying all ops in sequence.""" + + def __init__(self, *ops: Dict[str, Any], state: RunState) -> None: + super().__init__(*ops) + self.state = state + + def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: + if type(other) == RunLogPatch: + ops = self.ops + other.ops + state = jsonpatch.apply_patch(self.state, other.ops) + return RunLog(*ops, state=state) + + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) + + def __repr__(self) -> str: + from pprint import pformat + + return f"RunLog({pformat(self.state)})" + + +class LogStreamCallbackHandler(BaseTracer): + """A tracer that streams run logs to a stream.""" + + def __init__( + self, + *, + auto_close: bool = True, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + ) -> None: + super().__init__() + + self.auto_close = auto_close + self.include_names = include_names + self.include_types = include_types + self.include_tags = include_tags + self.exclude_names = exclude_names + self.exclude_types = exclude_types + self.exclude_tags = exclude_tags + + send_stream: Any + receive_stream: Any + send_stream, receive_stream = create_memory_object_stream( + math.inf, item_type=RunLogPatch + ) + self.lock = threading.Lock() + self.send_stream = send_stream + self.receive_stream = receive_stream + self._key_map_by_run_id: Dict[UUID, str] = {} + self._counter_map_by_name: Dict[str, int] = defaultdict(int) + self.root_id: Optional[UUID] = None + + def __aiter__(self) -> AsyncIterator[RunLogPatch]: + return self.receive_stream.__aiter__() + + def include_run(self, run: Run) -> bool: + if run.id == self.root_id: + return False + + run_tags = run.tags or [] + + if ( + self.include_names is None + and self.include_types is None + and self.include_tags is None + ): + include = True + else: + include = False + + if self.include_names is not None: + include = include or run.name in self.include_names + if self.include_types is not None: + include = include or run.run_type in self.include_types + if self.include_tags is not None: + include = include or any(tag in self.include_tags for tag in run_tags) + + if self.exclude_names is not None: + include = include and run.name not in self.exclude_names + if self.exclude_types is not None: + include = include and run.run_type not in self.exclude_types + if self.exclude_tags is not None: + include = include and all(tag not in self.exclude_tags 
for tag in run_tags) + + return include + + def _persist_run(self, run: Run) -> None: + # This is a legacy method only called once for an entire run tree + # therefore not useful here + pass + + def _on_run_create(self, run: Run) -> None: + """Start a run.""" + if self.root_id is None: + self.root_id = run.id + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "replace", + "path": "", + "value": RunState( + id=str(run.id), + streamed_output=[], + final_output=None, + logs={}, + ), + } + ) + ) + + if not self.include_run(run): + return + + # Determine previous index, increment by 1 + with self.lock: + self._counter_map_by_name[run.name] += 1 + count = self._counter_map_by_name[run.name] + self._key_map_by_run_id[run.id] = ( + run.name if count == 1 else f"{run.name}:{count}" + ) + + # Add the run to the stream + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "add", + "path": f"/logs/{self._key_map_by_run_id[run.id]}", + "value": LogEntry( + id=str(run.id), + name=run.name, + type=run.run_type, + tags=run.tags or [], + metadata=(run.extra or {}).get("metadata", {}), + start_time=run.start_time.isoformat(timespec="milliseconds"), + streamed_output_str=[], + final_output=None, + end_time=None, + ), + } + ) + ) + + def _on_run_update(self, run: Run) -> None: + """Finish a run.""" + try: + index = self._key_map_by_run_id.get(run.id) + + if index is None: + return + + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "add", + "path": f"/logs/{index}/final_output", + # to undo the dumpd done by some runnables / tracer / etc + "value": load(run.outputs), + }, + { + "op": "add", + "path": f"/logs/{index}/end_time", + "value": run.end_time.isoformat(timespec="milliseconds") + if run.end_time is not None + else None, + }, + ) + ) + finally: + if run.id == self.root_id: + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "replace", + "path": "/final_output", + "value": load(run.outputs), + } + ) + ) + if self.auto_close: + self.send_stream.close() + + def _on_llm_new_token( + self, + run: Run, + token: str, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], + ) -> None: + """Process new LLM token.""" + index = self._key_map_by_run_id.get(run.id) + + if index is None: + return + + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "add", + "path": f"/logs/{index}/streamed_output_str/-", + "value": token, + } + ) + ) diff --git a/libs/core/langchain_core/callbacks/tracers/root_listeners.py b/libs/core/langchain_core/callbacks/tracers/root_listeners.py new file mode 100644 index 00000000000..a693ae1f1a5 --- /dev/null +++ b/libs/core/langchain_core/callbacks/tracers/root_listeners.py @@ -0,0 +1,54 @@ +from typing import Callable, Optional, Union +from uuid import UUID + +from langchain_core.callbacks.tracers.base import BaseTracer +from langchain_core.callbacks.tracers.schemas import Run +from langchain_core.runnables.config import ( + RunnableConfig, + call_func_with_variable_args, +) + +Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] + + +class RootListenersTracer(BaseTracer): + def __init__( + self, + *, + config: RunnableConfig, + on_start: Optional[Listener], + on_end: Optional[Listener], + on_error: Optional[Listener], + ) -> None: + super().__init__() + + self.config = config + self._arg_on_start = on_start + self._arg_on_end = on_end + self._arg_on_error = on_error + self.root_id: Optional[UUID] = None + + def _persist_run(self, run: Run) -> None: + # This is a legacy method only called once for an entire run tree + # 
therefore not useful here + pass + + def _on_run_create(self, run: Run) -> None: + if self.root_id is not None: + return + + self.root_id = run.id + + if self._arg_on_start is not None: + call_func_with_variable_args(self._arg_on_start, run, self.config) + + def _on_run_update(self, run: Run) -> None: + if run.id != self.root_id: + return + + if run.error is None: + if self._arg_on_end is not None: + call_func_with_variable_args(self._arg_on_end, run, self.config) + else: + if self._arg_on_error is not None: + call_func_with_variable_args(self._arg_on_error, run, self.config) diff --git a/libs/core/langchain_core/callbacks/tracers/run_collector.py b/libs/core/langchain_core/callbacks/tracers/run_collector.py new file mode 100644 index 00000000000..e03ab00aa17 --- /dev/null +++ b/libs/core/langchain_core/callbacks/tracers/run_collector.py @@ -0,0 +1,52 @@ +"""A tracer that collects all nested runs in a list.""" + +from typing import Any, List, Optional, Union +from uuid import UUID + +from langchain_core.callbacks.tracers.base import BaseTracer +from langchain_core.callbacks.tracers.schemas import Run + + +class RunCollectorCallbackHandler(BaseTracer): + """ + A tracer that collects all nested runs in a list. + + This tracer is useful for inspection and evaluation purposes. + + Parameters + ---------- + example_id : Optional[Union[UUID, str]], default=None + The ID of the example being traced. It can be either a UUID or a string. + """ + + name: str = "run-collector_callback_handler" + + def __init__( + self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any + ) -> None: + """ + Initialize the RunCollectorCallbackHandler. + + Parameters + ---------- + example_id : Optional[Union[UUID, str]], default=None + The ID of the example being traced. It can be either a UUID or a string. + """ + super().__init__(**kwargs) + self.example_id = ( + UUID(example_id) if isinstance(example_id, str) else example_id + ) + self.traced_runs: List[Run] = [] + + def _persist_run(self, run: Run) -> None: + """ + Persist a run by adding it to the traced_runs list. + + Parameters + ---------- + run : Run + The run to be persisted. + """ + run_ = run.copy() + run_.reference_example_id = self.example_id + self.traced_runs.append(run_) diff --git a/libs/core/langchain_core/callbacks/tracers/schemas.py b/libs/core/langchain_core/callbacks/tracers/schemas.py new file mode 100644 index 00000000000..93436b70a41 --- /dev/null +++ b/libs/core/langchain_core/callbacks/tracers/schemas.py @@ -0,0 +1,140 @@ +"""Schemas for tracers.""" +from __future__ import annotations + +import datetime +import warnings +from typing import Any, Dict, List, Optional, Type +from uuid import UUID + +from langsmith.schemas import RunBase as BaseRunV2 +from langsmith.schemas import RunTypeEnum as RunTypeEnumDep + +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.schema import LLMResult + + +def RunTypeEnum() -> Type[RunTypeEnumDep]: + """RunTypeEnum.""" + warnings.warn( + "RunTypeEnum is deprecated. Please directly use a string instead" + " (e.g. 
'llm', 'chain', 'tool').", + DeprecationWarning, + ) + return RunTypeEnumDep + + +class TracerSessionV1Base(BaseModel): + """Base class for TracerSessionV1.""" + + start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + name: Optional[str] = None + extra: Optional[Dict[str, Any]] = None + + +class TracerSessionV1Create(TracerSessionV1Base): + """Create class for TracerSessionV1.""" + + +class TracerSessionV1(TracerSessionV1Base): + """TracerSessionV1 schema.""" + + id: int + + +class TracerSessionBase(TracerSessionV1Base): + """Base class for TracerSession.""" + + tenant_id: UUID + + +class TracerSession(TracerSessionBase): + """TracerSessionV1 schema for the V2 API.""" + + id: UUID + + +class BaseRun(BaseModel): + """Base class for Run.""" + + uuid: str + parent_uuid: Optional[str] = None + start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + extra: Optional[Dict[str, Any]] = None + execution_order: int + child_execution_order: int + serialized: Dict[str, Any] + session_id: int + error: Optional[str] = None + + +class LLMRun(BaseRun): + """Class for LLMRun.""" + + prompts: List[str] + response: Optional[LLMResult] = None + + +class ChainRun(BaseRun): + """Class for ChainRun.""" + + inputs: Dict[str, Any] + outputs: Optional[Dict[str, Any]] = None + child_llm_runs: List[LLMRun] = Field(default_factory=list) + child_chain_runs: List[ChainRun] = Field(default_factory=list) + child_tool_runs: List[ToolRun] = Field(default_factory=list) + + +class ToolRun(BaseRun): + """Class for ToolRun.""" + + tool_input: str + output: Optional[str] = None + action: str + child_llm_runs: List[LLMRun] = Field(default_factory=list) + child_chain_runs: List[ChainRun] = Field(default_factory=list) + child_tool_runs: List[ToolRun] = Field(default_factory=list) + + +# Begin V2 API Schemas + + +class Run(BaseRunV2): + """Run schema for the V2 API in the Tracer.""" + + execution_order: int + child_execution_order: int + child_runs: List[Run] = Field(default_factory=list) + tags: Optional[List[str]] = Field(default_factory=list) + events: List[Dict[str, Any]] = Field(default_factory=list) + + @root_validator(pre=True) + def assign_name(cls, values: dict) -> dict: + """Assign name to the run.""" + if values.get("name") is None: + if "name" in values["serialized"]: + values["name"] = values["serialized"]["name"] + elif "id" in values["serialized"]: + values["name"] = values["serialized"]["id"][-1] + if values.get("events") is None: + values["events"] = [] + return values + + +ChainRun.update_forward_refs() +ToolRun.update_forward_refs() +Run.update_forward_refs() + +__all__ = [ + "BaseRun", + "ChainRun", + "LLMRun", + "Run", + "RunTypeEnum", + "ToolRun", + "TracerSession", + "TracerSessionBase", + "TracerSessionV1", + "TracerSessionV1Base", + "TracerSessionV1Create", +] diff --git a/libs/core/langchain_core/callbacks/tracers/stdout.py b/libs/core/langchain_core/callbacks/tracers/stdout.py new file mode 100644 index 00000000000..8a6b61e3133 --- /dev/null +++ b/libs/core/langchain_core/callbacks/tracers/stdout.py @@ -0,0 +1,178 @@ +import json +from typing import Any, Callable, List + +from langchain_core.callbacks.tracers.base import BaseTracer +from langchain_core.callbacks.tracers.schemas import Run +from langchain_core.utils.input import get_bolded_text, get_colored_text + + +def try_json_stringify(obj: Any, fallback: str) -> str: + """ + Try to stringify an object to JSON. 
+ Args: + obj: Object to stringify. + fallback: Fallback string to return if the object cannot be stringified. + + Returns: + A JSON string if the object can be stringified, otherwise the fallback string. + + """ + try: + return json.dumps(obj, indent=2, ensure_ascii=False) + except Exception: + return fallback + + +def elapsed(run: Any) -> str: + """Get the elapsed time of a run. + + Args: + run: any object with a start_time and end_time attribute. + + Returns: + A string with the elapsed time in seconds or + milliseconds if time is less than a second. + + """ + elapsed_time = run.end_time - run.start_time + milliseconds = elapsed_time.total_seconds() * 1000 + if milliseconds < 1000: + return f"{milliseconds:.0f}ms" + return f"{(milliseconds / 1000):.2f}s" + + +class FunctionCallbackHandler(BaseTracer): + """Tracer that calls a function with a single str parameter.""" + + name: str = "function_callback_handler" + + def __init__(self, function: Callable[[str], None], **kwargs: Any) -> None: + super().__init__(**kwargs) + self.function_callback = function + + def _persist_run(self, run: Run) -> None: + pass + + def get_parents(self, run: Run) -> List[Run]: + parents = [] + current_run = run + while current_run.parent_run_id: + parent = self.run_map.get(str(current_run.parent_run_id)) + if parent: + parents.append(parent) + current_run = parent + else: + break + return parents + + def get_breadcrumbs(self, run: Run) -> str: + parents = self.get_parents(run)[::-1] + string = " > ".join( + f"{parent.execution_order}:{parent.run_type}:{parent.name}" + if i != len(parents) - 1 + else f"{parent.execution_order}:{parent.run_type}:{parent.name}" + for i, parent in enumerate(parents + [run]) + ) + return string + + # logging methods + def _on_chain_start(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + run_type = run.run_type.capitalize() + self.function_callback( + f"{get_colored_text('[chain/start]', color='green')} " + + get_bolded_text(f"[{crumbs}] Entering {run_type} run with input:\n") + + f"{try_json_stringify(run.inputs, '[inputs]')}" + ) + + def _on_chain_end(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + run_type = run.run_type.capitalize() + self.function_callback( + f"{get_colored_text('[chain/end]', color='blue')} " + + get_bolded_text( + f"[{crumbs}] [{elapsed(run)}] Exiting {run_type} run with output:\n" + ) + + f"{try_json_stringify(run.outputs, '[outputs]')}" + ) + + def _on_chain_error(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + run_type = run.run_type.capitalize() + self.function_callback( + f"{get_colored_text('[chain/error]', color='red')} " + + get_bolded_text( + f"[{crumbs}] [{elapsed(run)}] {run_type} run errored with error:\n" + ) + + f"{try_json_stringify(run.error, '[error]')}" + ) + + def _on_llm_start(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + inputs = ( + {"prompts": [p.strip() for p in run.inputs["prompts"]]} + if "prompts" in run.inputs + else run.inputs + ) + self.function_callback( + f"{get_colored_text('[llm/start]', color='green')} " + + get_bolded_text(f"[{crumbs}] Entering LLM run with input:\n") + + f"{try_json_stringify(inputs, '[inputs]')}" + ) + + def _on_llm_end(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + self.function_callback( + f"{get_colored_text('[llm/end]', color='blue')} " + + get_bolded_text( + f"[{crumbs}] [{elapsed(run)}] Exiting LLM run with output:\n" + ) + + f"{try_json_stringify(run.outputs, '[response]')}" + ) + + def _on_llm_error(self, run: Run) -> 
None: + crumbs = self.get_breadcrumbs(run) + self.function_callback( + f"{get_colored_text('[llm/error]', color='red')} " + + get_bolded_text( + f"[{crumbs}] [{elapsed(run)}] LLM run errored with error:\n" + ) + + f"{try_json_stringify(run.error, '[error]')}" + ) + + def _on_tool_start(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + self.function_callback( + f'{get_colored_text("[tool/start]", color="green")} ' + + get_bolded_text(f"[{crumbs}] Entering Tool run with input:\n") + + f'"{run.inputs["input"].strip()}"' + ) + + def _on_tool_end(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + if run.outputs: + self.function_callback( + f'{get_colored_text("[tool/end]", color="blue")} ' + + get_bolded_text( + f"[{crumbs}] [{elapsed(run)}] Exiting Tool run with output:\n" + ) + + f'"{run.outputs["output"].strip()}"' + ) + + def _on_tool_error(self, run: Run) -> None: + crumbs = self.get_breadcrumbs(run) + self.function_callback( + f"{get_colored_text('[tool/error]', color='red')} " + + get_bolded_text(f"[{crumbs}] [{elapsed(run)}] ") + + f"Tool run errored with error:\n" + f"{run.error}" + ) + + +class ConsoleCallbackHandler(FunctionCallbackHandler): + """Tracer that prints to the console.""" + + name: str = "console_callback_handler" + + def __init__(self, **kwargs: Any) -> None: + super().__init__(function=print, **kwargs) diff --git a/libs/core/langchain_core/chat_model.py b/libs/core/langchain_core/chat_model.py new file mode 100644 index 00000000000..ebe77711d39 --- /dev/null +++ b/libs/core/langchain_core/chat_model.py @@ -0,0 +1,735 @@ +import asyncio +import inspect +import warnings +from abc import ABC, abstractmethod +from functools import partial +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Sequence, + cast, +) + +from langchain_core.callbacks.base import BaseCallbackManager +from langchain_core.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForLLMRun, + CallbackManager, + CallbackManagerForLLMRun, + Callbacks, +) +from langchain_core.globals import get_llm_cache +from langchain_core.load.dump import dumpd, dumps +from langchain_core.prompts.base import StringPromptValue +from langchain_core.prompts.chat import ChatPromptValue +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.runnables import RunnableConfig +from langchain_core.schema import ( + ChatGeneration, + ChatResult, + LLMResult, + PromptValue, + RunInfo, +) +from langchain_core.schema.language_model import BaseLanguageModel, LanguageModelInput +from langchain_core.schema.messages import ( + AIMessage, + AnyMessage, + BaseMessage, + BaseMessageChunk, + HumanMessage, +) +from langchain_core.schema.output import ChatGenerationChunk + + +def _get_verbosity() -> bool: + from langchain_core.globals import get_verbose + + return get_verbose() + + +def _generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult: + generation: Optional[ChatGenerationChunk] = None + for chunk in stream: + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return ChatResult(generations=[generation]) + + +async def _agenerate_from_stream( + stream: AsyncIterator[ChatGenerationChunk], +) -> ChatResult: + generation: Optional[ChatGenerationChunk] = None + async for chunk in stream: + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return ChatResult(generations=[generation]) + + +class 
BaseChatModel(BaseLanguageModel[BaseMessage], ABC): + """Base class for Chat models.""" + + cache: Optional[bool] = None + """Whether to cache the response.""" + verbose: bool = Field(default_factory=_get_verbosity) + """Whether to print out response text.""" + callbacks: Callbacks = Field(default=None, exclude=True) + """Callbacks to add to the run trace.""" + callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) + """Callback manager to add to the run trace.""" + tags: Optional[List[str]] = Field(default=None, exclude=True) + """Tags to add to the run trace.""" + metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True) + """Metadata to add to the run trace.""" + + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + # --- Runnable methods --- + + @property + def OutputType(self) -> Any: + """Get the output type for this runnable.""" + return AnyMessage + + def _convert_input(self, input: LanguageModelInput) -> PromptValue: + if isinstance(input, PromptValue): + return input + elif isinstance(input, str): + return StringPromptValue(text=input) + elif isinstance(input, list): + return ChatPromptValue(messages=input) + else: + raise ValueError( + f"Invalid input type {type(input)}. " + "Must be a PromptValue, str, or list of BaseMessages." + ) + + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> BaseMessage: + config = config or {} + return cast( + ChatGeneration, + self.generate_prompt( + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + **kwargs, + ).generations[0][0], + ).message + + async def ainvoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> BaseMessage: + config = config or {} + llm_result = await self.agenerate_prompt( + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + **kwargs, + ) + return cast(ChatGeneration, llm_result.generations[0][0]).message + + def stream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> Iterator[BaseMessageChunk]: + if type(self)._stream == BaseChatModel._stream: + # model doesn't implement streaming, so use default implementation + yield cast( + BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs) + ) + else: + config = config or {} + messages = self._convert_input(input).to_messages() + params = self._get_invocation_params(stop=stop, **kwargs) + options = {"stop": stop, **kwargs} + callback_manager = CallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + config.get("metadata"), + self.metadata, + ) + (run_manager,) = callback_manager.on_chat_model_start( + 
dumpd(self), + [messages], + invocation_params=params, + options=options, + name=config.get("run_name"), + ) + try: + generation: Optional[ChatGenerationChunk] = None + for chunk in self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ): + yield chunk.message + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + except BaseException as e: + run_manager.on_llm_error(e) + raise e + else: + run_manager.on_llm_end( + LLMResult(generations=[[generation]]), + ) + + async def astream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> AsyncIterator[BaseMessageChunk]: + if type(self)._astream == BaseChatModel._astream: + # model doesn't implement streaming, so use default implementation + yield cast( + BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs) + ) + else: + config = config or {} + messages = self._convert_input(input).to_messages() + params = self._get_invocation_params(stop=stop, **kwargs) + options = {"stop": stop, **kwargs} + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + config.get("metadata"), + self.metadata, + ) + (run_manager,) = await callback_manager.on_chat_model_start( + dumpd(self), + [messages], + invocation_params=params, + options=options, + name=config.get("run_name"), + ) + try: + generation: Optional[ChatGenerationChunk] = None + async for chunk in self._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ): + yield chunk.message + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + except BaseException as e: + await run_manager.on_llm_error(e) + raise e + else: + await run_manager.on_llm_end( + LLMResult(generations=[[generation]]), + ) + + # --- Custom methods --- + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + return {} + + def _get_invocation_params( + self, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> dict: + params = self.dict() + params["stop"] = stop + return {**params, **kwargs} + + def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str: + if self.is_lc_serializable(): + params = {**kwargs, **{"stop": stop}} + param_string = str(sorted([(k, v) for k, v in params.items()])) + llm_string = dumps(self) + return llm_string + "---" + param_string + else: + params = self._get_invocation_params(stop=stop, **kwargs) + params = {**params, **kwargs} + return str(sorted([(k, v) for k, v in params.items()])) + + def generate( + self, + messages: List[List[BaseMessage]], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, + **kwargs: Any, + ) -> LLMResult: + """Top Level call""" + params = self._get_invocation_params(stop=stop, **kwargs) + options = {"stop": stop} + + callback_manager = CallbackManager.configure( + callbacks, + self.callbacks, + self.verbose, + tags, + self.tags, + metadata, + self.metadata, + ) + run_managers = callback_manager.on_chat_model_start( + dumpd(self), + messages, + invocation_params=params, + options=options, + name=run_name, + ) + results = [] + for i, m in enumerate(messages): + try: + results.append( + self._generate_with_cache( + m, + stop=stop, + run_manager=run_managers[i] 
if run_managers else None, + **kwargs, + ) + ) + except BaseException as e: + if run_managers: + run_managers[i].on_llm_error(e) + raise e + flattened_outputs = [ + LLMResult(generations=[res.generations], llm_output=res.llm_output) + for res in results + ] + llm_output = self._combine_llm_outputs([res.llm_output for res in results]) + generations = [res.generations for res in results] + output = LLMResult(generations=generations, llm_output=llm_output) + if run_managers: + run_infos = [] + for manager, flattened_output in zip(run_managers, flattened_outputs): + manager.on_llm_end(flattened_output) + run_infos.append(RunInfo(run_id=manager.run_id)) + output.run = run_infos + return output + + async def agenerate( + self, + messages: List[List[BaseMessage]], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, + **kwargs: Any, + ) -> LLMResult: + """Top Level call""" + params = self._get_invocation_params(stop=stop, **kwargs) + options = {"stop": stop} + + callback_manager = AsyncCallbackManager.configure( + callbacks, + self.callbacks, + self.verbose, + tags, + self.tags, + metadata, + self.metadata, + ) + + run_managers = await callback_manager.on_chat_model_start( + dumpd(self), + messages, + invocation_params=params, + options=options, + name=run_name, + ) + + results = await asyncio.gather( + *[ + self._agenerate_with_cache( + m, + stop=stop, + run_manager=run_managers[i] if run_managers else None, + **kwargs, + ) + for i, m in enumerate(messages) + ], + return_exceptions=True, + ) + exceptions = [] + for i, res in enumerate(results): + if isinstance(res, BaseException): + if run_managers: + await run_managers[i].on_llm_error(res) + exceptions.append(res) + if exceptions: + if run_managers: + await asyncio.gather( + *[ + run_manager.on_llm_end( + LLMResult( + generations=[res.generations], llm_output=res.llm_output + ) + ) + for run_manager, res in zip(run_managers, results) + if not isinstance(res, Exception) + ] + ) + raise exceptions[0] + flattened_outputs = [ + LLMResult(generations=[res.generations], llm_output=res.llm_output) + for res in results + ] + llm_output = self._combine_llm_outputs([res.llm_output for res in results]) + generations = [res.generations for res in results] + output = LLMResult(generations=generations, llm_output=llm_output) + await asyncio.gather( + *[ + run_manager.on_llm_end(flattened_output) + for run_manager, flattened_output in zip( + run_managers, flattened_outputs + ) + ] + ) + if run_managers: + output.run = [ + RunInfo(run_id=run_manager.run_id) for run_manager in run_managers + ] + return output + + def generate_prompt( + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, + ) -> LLMResult: + prompt_messages = [p.to_messages() for p in prompts] + return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs) + + async def agenerate_prompt( + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, + ) -> LLMResult: + prompt_messages = [p.to_messages() for p in prompts] + return await self.agenerate( + prompt_messages, stop=stop, callbacks=callbacks, **kwargs + ) + + def _generate_with_cache( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + 
new_arg_supported = inspect.signature(self._generate).parameters.get( + "run_manager" + ) + disregard_cache = self.cache is not None and not self.cache + llm_cache = get_llm_cache() + if llm_cache is None or disregard_cache: + # This happens when langchain.cache is None, but self.cache is True + if self.cache is not None and self.cache: + raise ValueError( + "Asked to cache, but no cache found at `langchain.cache`." + ) + if new_arg_supported: + return self._generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + return self._generate(messages, stop=stop, **kwargs) + else: + llm_string = self._get_llm_string(stop=stop, **kwargs) + prompt = dumps(messages) + cache_val = llm_cache.lookup(prompt, llm_string) + if isinstance(cache_val, list): + return ChatResult(generations=cache_val) + else: + if new_arg_supported: + result = self._generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + result = self._generate(messages, stop=stop, **kwargs) + llm_cache.update(prompt, llm_string, result.generations) + return result + + async def _agenerate_with_cache( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + new_arg_supported = inspect.signature(self._agenerate).parameters.get( + "run_manager" + ) + disregard_cache = self.cache is not None and not self.cache + llm_cache = get_llm_cache() + if llm_cache is None or disregard_cache: + # This happens when langchain.cache is None, but self.cache is True + if self.cache is not None and self.cache: + raise ValueError( + "Asked to cache, but no cache found at `langchain.cache`." + ) + if new_arg_supported: + return await self._agenerate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + return await self._agenerate(messages, stop=stop, **kwargs) + else: + llm_string = self._get_llm_string(stop=stop, **kwargs) + prompt = dumps(messages) + cache_val = llm_cache.lookup(prompt, llm_string) + if isinstance(cache_val, list): + return ChatResult(generations=cache_val) + else: + if new_arg_supported: + result = await self._agenerate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + result = await self._agenerate(messages, stop=stop, **kwargs) + llm_cache.update(prompt, llm_string, result.generations) + return result + + @abstractmethod + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Top Level call""" + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Top Level call""" + return await asyncio.get_running_loop().run_in_executor( + None, partial(self._generate, **kwargs), messages, stop, run_manager + ) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + raise NotImplementedError() + + def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + raise NotImplementedError() + + def __call__( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = 
None, + callbacks: Callbacks = None, + **kwargs: Any, + ) -> BaseMessage: + generation = self.generate( + [messages], stop=stop, callbacks=callbacks, **kwargs + ).generations[0][0] + if isinstance(generation, ChatGeneration): + return generation.message + else: + raise ValueError("Unexpected generation type") + + async def _call_async( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, + ) -> BaseMessage: + result = await self.agenerate( + [messages], stop=stop, callbacks=callbacks, **kwargs + ) + generation = result.generations[0][0] + if isinstance(generation, ChatGeneration): + return generation.message + else: + raise ValueError("Unexpected generation type") + + def call_as_llm( + self, message: str, stop: Optional[List[str]] = None, **kwargs: Any + ) -> str: + return self.predict(message, stop=stop, **kwargs) + + def predict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: + if stop is None: + _stop = None + else: + _stop = list(stop) + result = self([HumanMessage(content=text)], stop=_stop, **kwargs) + if isinstance(result.content, str): + return result.content + else: + raise ValueError("Cannot use predict when output is not a string.") + + def predict_messages( + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> BaseMessage: + if stop is None: + _stop = None + else: + _stop = list(stop) + return self(messages, stop=_stop, **kwargs) + + async def apredict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: + if stop is None: + _stop = None + else: + _stop = list(stop) + result = await self._call_async( + [HumanMessage(content=text)], stop=_stop, **kwargs + ) + if isinstance(result.content, str): + return result.content + else: + raise ValueError("Cannot use predict when output is not a string.") + + async def apredict_messages( + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> BaseMessage: + if stop is None: + _stop = None + else: + _stop = list(stop) + return await self._call_async(messages, stop=_stop, **kwargs) + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get the identifying parameters.""" + return {} + + @property + @abstractmethod + def _llm_type(self) -> str: + """Return type of chat model.""" + + def dict(self, **kwargs: Any) -> Dict: + """Return a dictionary of the LLM.""" + starter_dict = dict(self._identifying_params) + starter_dict["_type"] = self._llm_type + return starter_dict + + +class SimpleChatModel(BaseChatModel): + """Simple Chat Model.""" + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs) + message = AIMessage(content=output_str) + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + @abstractmethod + def _call( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Simpler interface.""" + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + func = partial( + self._generate, 
messages, stop=stop, run_manager=run_manager, **kwargs + ) + return await asyncio.get_event_loop().run_in_executor(None, func) diff --git a/libs/core/langchain_core/env.py b/libs/core/langchain_core/env.py new file mode 100644 index 00000000000..9d85f0fb234 --- /dev/null +++ b/libs/core/langchain_core/env.py @@ -0,0 +1,17 @@ +import platform +from functools import lru_cache + + +@lru_cache(maxsize=1) +def get_runtime_environment() -> dict: + """Get information about the LangChain runtime environment.""" + # Lazy import to avoid circular imports + from langchain_core import __version__ + + return { + "library_version": __version__, + "library": "langchain", + "platform": platform.platform(), + "runtime": "python", + "runtime_version": platform.python_version(), + } diff --git a/libs/core/langchain_core/globals/__init__.py b/libs/core/langchain_core/globals/__init__.py new file mode 100644 index 00000000000..da625899bab --- /dev/null +++ b/libs/core/langchain_core/globals/__init__.py @@ -0,0 +1,197 @@ +# flake8: noqa +"""Global values and configuration that apply to all of LangChain.""" +import warnings +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from langchain_core.schema import BaseCache + + +# DO NOT USE THESE VALUES DIRECTLY! +# Use them only via `get_()` and `set_()` below, +# or else your code may behave unexpectedly with other uses of these global settings: +# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004 +_verbose: bool = False +_debug: bool = False +_llm_cache: Optional["BaseCache"] = None + + +def set_verbose(value: bool) -> None: + """Set a new value for the `verbose` global setting.""" + try: + import langchain + + # We're about to run some deprecated code, don't report warnings from it. + # The user called the correct (non-deprecated) code path and shouldn't get warnings. + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=( + "Importing verbose from langchain_core root module is no longer supported" + ), + ) + # N.B.: This is a workaround for an unfortunate quirk of Python's + # module-level `__getattr__()` implementation: + # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004 + # + # Remove it once `langchain.verbose` is no longer supported, and once all users + # have migrated to using `set_verbose()` here. + langchain.verbose = value + except ImportError: + pass + + global _verbose + _verbose = value + + +def get_verbose() -> bool: + """Get the value of the `verbose` global setting.""" + try: + import langchain + + # We're about to run some deprecated code, don't report warnings from it. + # The user called the correct (non-deprecated) code path and shouldn't get warnings. + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=( + "Importing verbose from langchain_core root module is no longer supported" + ), + ) + # N.B.: This is a workaround for an unfortunate quirk of Python's + # module-level `__getattr__()` implementation: + # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004 + # + # Remove it once `langchain.verbose` is no longer supported, and once all users + # have migrated to using `set_verbose()` here. + # + # In the meantime, the `verbose` setting is considered True if either the old + # or the new value are True. This accommodates users who haven't migrated + # to using `set_verbose()` yet. 
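For a sense of how the SimpleChatModel helper above is meant to be used, here is a minimal, hypothetical subclass; it relies only on names added by this patch (`SimpleChatModel` in `langchain_core.chat_model`, `CallbackManagerForLLMRun` in `langchain_core.callbacks.manager`, `BaseMessage` in `langchain_core.schema.messages`), and the echo behaviour is illustrative only.

.. code-block:: python

    from typing import Any, List, Optional

    from langchain_core.callbacks.manager import CallbackManagerForLLMRun
    from langchain_core.chat_model import SimpleChatModel
    from langchain_core.schema.messages import BaseMessage


    class EchoChatModel(SimpleChatModel):
        """Toy chat model that echoes the last message back."""

        @property
        def _llm_type(self) -> str:
            return "echo-chat"

        def _call(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> str:
            # SimpleChatModel wraps this string into an AIMessage / ChatGeneration.
            return str(messages[-1].content)

With that in place, `EchoChatModel().predict("hello")` should simply return `"hello"`.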
Those users are getting deprecation warnings + # directing them to use `set_verbose()` when they import `langchain.verbose`. + old_verbose = langchain.verbose + except ImportError: + old_verbose = False + + global _verbose + return _verbose or old_verbose + + +def set_debug(value: bool) -> None: + """Set a new value for the `debug` global setting.""" + try: + import langchain + + # We're about to run some deprecated code, don't report warnings from it. + # The user called the correct (non-deprecated) code path and shouldn't get warnings. + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Importing debug from langchain_core root module is no longer supported", + ) + # N.B.: This is a workaround for an unfortunate quirk of Python's + # module-level `__getattr__()` implementation: + # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004 + # + # Remove it once `langchain.debug` is no longer supported, and once all users + # have migrated to using `set_debug()` here. + langchain.debug = value + except ImportError: + pass + + global _debug + _debug = value + + +def get_debug() -> bool: + """Get the value of the `debug` global setting.""" + try: + import langchain + + # We're about to run some deprecated code, don't report warnings from it. + # The user called the correct (non-deprecated) code path and shouldn't get warnings. + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Importing debug from langchain_core root module is no longer supported", + ) + # N.B.: This is a workaround for an unfortunate quirk of Python's + # module-level `__getattr__()` implementation: + # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004 + # + # Remove it once `langchain.debug` is no longer supported, and once all users + # have migrated to using `set_debug()` here. + # + # In the meantime, the `debug` setting is considered True if either the old + # or the new value are True. This accommodates users who haven't migrated + # to using `set_debug()` yet. Those users are getting deprecation warnings + # directing them to use `set_debug()` when they import `langchain.debug`. + old_debug = langchain.debug + except ImportError: + old_debug = False + + global _debug + return _debug or old_debug + + +def set_llm_cache(value: Optional["BaseCache"]) -> None: + """Set a new LLM cache, overwriting the previous value, if any.""" + try: + import langchain + + # We're about to run some deprecated code, don't report warnings from it. + # The user called the correct (non-deprecated) code path and shouldn't get warnings. + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=( + "Importing llm_cache from langchain_core root module is no longer supported" + ), + ) + # N.B.: This is a workaround for an unfortunate quirk of Python's + # module-level `__getattr__()` implementation: + # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004 + # + # Remove it once `langchain.llm_cache` is no longer supported, and + # once all users have migrated to using `set_llm_cache()` here. + langchain.llm_cache = value + except ImportError: + pass + + global _llm_cache + _llm_cache = value + + +def get_llm_cache() -> "BaseCache": + """Get the value of the `llm_cache` global setting.""" + try: + import langchain + + # We're about to run some deprecated code, don't report warnings from it. + # The user called the correct (non-deprecated) code path and shouldn't get warnings.
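The intended use of these globals, sketched under the assumption that `BaseCache` (imported from `langchain_core.schema` above) exposes the `lookup`/`update`/`clear` methods the chat model and LLM callers in this patch rely on; the in-memory cache here is purely illustrative:

.. code-block:: python

    from typing import Any, Dict, Optional, Sequence, Tuple

    from langchain_core.globals import get_llm_cache, set_debug, set_llm_cache
    from langchain_core.schema import BaseCache, Generation


    class InMemoryCache(BaseCache):
        """Toy cache keyed by (prompt, llm_string)."""

        def __init__(self) -> None:
            self._store: Dict[Tuple[str, str], Sequence[Generation]] = {}

        def lookup(self, prompt: str, llm_string: str) -> Optional[Sequence[Generation]]:
            return self._store.get((prompt, llm_string))

        def update(self, prompt: str, llm_string: str, return_val: Sequence[Generation]) -> None:
            self._store[(prompt, llm_string)] = return_val

        def clear(self, **kwargs: Any) -> None:
            self._store.clear()


    set_debug(True)  # mirrored onto `langchain.debug` when that package is installed
    set_llm_cache(InMemoryCache())
    assert isinstance(get_llm_cache(), InMemoryCache)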
+ with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=( + "Importing llm_cache from langchain_core root module is no longer supported" + ), + ) + # N.B.: This is a workaround for an unfortunate quirk of Python's + # module-level `__getattr__()` implementation: + # https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004 + # + # Remove it once `langchain.llm_cache` is no longer supported, and + # once all users have migrated to using `set_llm_cache()` here. + # + # In the meantime, the `llm_cache` setting returns whichever of + # its two backing sources is truthy (not `None` and non-empty), + # or the old value if both are falsy. This accommodates users + # who haven't migrated to using `set_llm_cache()` yet. + # Those users are getting deprecation warnings directing them + # to use `set_llm_cache()` when they import `langchain.llm_cache`. + old_llm_cache = langchain.llm_cache + except ImportError: + old_llm_cache = None + + global _llm_cache + return _llm_cache or old_llm_cache diff --git a/libs/core/langchain_core/llm.py b/libs/core/langchain_core/llm.py new file mode 100644 index 00000000000..b48f9230a88 --- /dev/null +++ b/libs/core/langchain_core/llm.py @@ -0,0 +1,1077 @@ +"""Base interface for large language models to expose.""" +from __future__ import annotations + +import asyncio +import functools +import inspect +import json +import logging +import warnings +from abc import ABC, abstractmethod +from functools import partial +from pathlib import Path +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +import yaml +from tenacity import ( + RetryCallState, + before_sleep_log, + retry, + retry_base, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain_core.callbacks.base import BaseCallbackManager +from langchain_core.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForLLMRun, + CallbackManager, + CallbackManagerForLLMRun, + Callbacks, +) +from langchain_core.globals import get_llm_cache +from langchain_core.load.dump import dumpd +from langchain_core.prompts.base import StringPromptValue +from langchain_core.prompts.chat import ChatPromptValue +from langchain_core.pydantic_v1 import Field, root_validator, validator +from langchain_core.runnables import RunnableConfig +from langchain_core.runnables.config import get_config_list +from langchain_core.schema import Generation, LLMResult, PromptValue, RunInfo +from langchain_core.schema.language_model import BaseLanguageModel, LanguageModelInput +from langchain_core.schema.messages import AIMessage, BaseMessage, get_buffer_string +from langchain_core.schema.output import GenerationChunk + +logger = logging.getLogger(__name__) + + +def _get_verbosity() -> bool: + from langchain_core.globals import get_verbose + + return get_verbose() + + +@functools.lru_cache +def _log_error_once(msg: str) -> None: + """Log an error once.""" + logger.error(msg) + + +def create_base_retry_decorator( + error_types: List[Type[BaseException]], + max_retries: int = 1, + run_manager: Optional[ + Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] + ] = None, +) -> Callable[[Any], Any]: + """Create a retry decorator for a given LLM and provided list of error types.""" + + _logging = before_sleep_log(logger, logging.WARNING) + + def _before_sleep(retry_state: RetryCallState) -> None: + _logging(retry_state) + if run_manager: + if
isinstance(run_manager, AsyncCallbackManagerForLLMRun): + coro = run_manager.on_retry(retry_state) + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(coro) + else: + asyncio.run(coro) + except Exception as e: + _log_error_once(f"Error in on_retry: {e}") + else: + run_manager.on_retry(retry_state) + return None + + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + retry_instance: "retry_base" = retry_if_exception_type(error_types[0]) + for error in error_types[1:]: + retry_instance = retry_instance | retry_if_exception_type(error) + return retry( + reraise=True, + stop=stop_after_attempt(max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=retry_instance, + before_sleep=_before_sleep, + ) + + +def get_prompts( + params: Dict[str, Any], prompts: List[str] +) -> Tuple[Dict[int, List], str, List[int], List[str]]: + """Get prompts that are already cached.""" + llm_string = str(sorted([(k, v) for k, v in params.items()])) + missing_prompts = [] + missing_prompt_idxs = [] + existing_prompts = {} + llm_cache = get_llm_cache() + for i, prompt in enumerate(prompts): + if llm_cache is not None: + cache_val = llm_cache.lookup(prompt, llm_string) + if isinstance(cache_val, list): + existing_prompts[i] = cache_val + else: + missing_prompts.append(prompt) + missing_prompt_idxs.append(i) + return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts + + +def update_cache( + existing_prompts: Dict[int, List], + llm_string: str, + missing_prompt_idxs: List[int], + new_results: LLMResult, + prompts: List[str], +) -> Optional[dict]: + """Update the cache and get the LLM output.""" + llm_cache = get_llm_cache() + for i, result in enumerate(new_results.generations): + existing_prompts[missing_prompt_idxs[i]] = result + prompt = prompts[missing_prompt_idxs[i]] + if llm_cache is not None: + llm_cache.update(prompt, llm_string, result) + llm_output = new_results.llm_output + return llm_output + + +class BaseLLM(BaseLanguageModel[str], ABC): + """Base LLM abstract interface. + + It should take in a prompt and return a string.""" + + cache: Optional[bool] = None + verbose: bool = Field(default_factory=_get_verbosity) + """Whether to print out response text.""" + callbacks: Callbacks = Field(default=None, exclude=True) + callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) + tags: Optional[List[str]] = Field(default=None, exclude=True) + """Tags to add to the run trace.""" + metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True) + """Metadata to add to the run trace.""" + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values + + @validator("verbose", pre=True, always=True) + def set_verbose(cls, verbose: Optional[bool]) -> bool: + """If verbose is None, set it. + + This allows users to pass in None as verbose to access the global setting. 
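The `create_base_retry_decorator` helper defined above is meant to wrap whatever function actually calls the provider; a rough, self-contained sketch (the flaky function and the chosen error types are hypothetical, not part of the patch):

.. code-block:: python

    import random

    from langchain_core.llm import create_base_retry_decorator

    retry_decorator = create_base_retry_decorator(
        error_types=[TimeoutError, ConnectionError],  # provider-specific errors in practice
        max_retries=4,
    )


    @retry_decorator
    def flaky_completion(prompt: str) -> str:
        # Stand-in for a provider call that sometimes times out; retried with
        # exponential backoff (4s minimum, 10s maximum) for the listed error types.
        if random.random() < 0.5:
            raise TimeoutError("transient failure")
        return f"completion for: {prompt}"


    print(flaky_completion("hello"))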
+ """ + if verbose is None: + return _get_verbosity() + else: + return verbose + + # --- Runnable methods --- + + @property + def OutputType(self) -> Type[str]: + """Get the input type for this runnable.""" + return str + + def _convert_input(self, input: LanguageModelInput) -> PromptValue: + if isinstance(input, PromptValue): + return input + elif isinstance(input, str): + return StringPromptValue(text=input) + elif isinstance(input, list): + return ChatPromptValue(messages=input) + else: + raise ValueError( + f"Invalid input type {type(input)}. " + "Must be a PromptValue, str, or list of BaseMessages." + ) + + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> str: + config = config or {} + return ( + self.generate_prompt( + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + **kwargs, + ) + .generations[0][0] + .text + ) + + async def ainvoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> str: + config = config or {} + llm_result = await self.agenerate_prompt( + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + **kwargs, + ) + return llm_result.generations[0][0].text + + def batch( + self, + inputs: List[LanguageModelInput], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> List[str]: + if not inputs: + return [] + + config = get_config_list(config, len(inputs)) + max_concurrency = config[0].get("max_concurrency") + + if max_concurrency is None: + try: + llm_result = self.generate_prompt( + [self._convert_input(input) for input in inputs], + callbacks=[c.get("callbacks") for c in config], + tags=[c.get("tags") for c in config], + metadata=[c.get("metadata") for c in config], + run_name=[c.get("run_name") for c in config], + **kwargs, + ) + return [g[0].text for g in llm_result.generations] + except Exception as e: + if return_exceptions: + return cast(List[str], [e for _ in inputs]) + else: + raise e + else: + batches = [ + inputs[i : i + max_concurrency] + for i in range(0, len(inputs), max_concurrency) + ] + config = [{**c, "max_concurrency": None} for c in config] # type: ignore[misc] + return [ + output + for i, batch in enumerate(batches) + for output in self.batch( + batch, + config=config[i * max_concurrency : (i + 1) * max_concurrency], + return_exceptions=return_exceptions, + **kwargs, + ) + ] + + async def abatch( + self, + inputs: List[LanguageModelInput], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> List[str]: + if not inputs: + return [] + config = get_config_list(config, len(inputs)) + max_concurrency = config[0].get("max_concurrency") + + if max_concurrency is None: + try: + llm_result = await self.agenerate_prompt( + [self._convert_input(input) for input in inputs], + callbacks=[c.get("callbacks") for c in config], + tags=[c.get("tags") for c in config], + metadata=[c.get("metadata") for c in config], + run_name=[c.get("run_name") for c in config], + **kwargs, + ) + return [g[0].text for g in llm_result.generations] + except Exception as e: + if 
return_exceptions: + return cast(List[str], [e for _ in inputs]) + else: + raise e + else: + batches = [ + inputs[i : i + max_concurrency] + for i in range(0, len(inputs), max_concurrency) + ] + config = [{**c, "max_concurrency": None} for c in config] # type: ignore[misc] + return [ + output + for i, batch in enumerate(batches) + for output in await self.abatch( + batch, + config=config[i * max_concurrency : (i + 1) * max_concurrency], + return_exceptions=return_exceptions, + **kwargs, + ) + ] + + def stream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> Iterator[str]: + if type(self)._stream == BaseLLM._stream: + # model doesn't implement streaming, so use default implementation + yield self.invoke(input, config=config, stop=stop, **kwargs) + else: + prompt = self._convert_input(input).to_string() + config = config or {} + params = self.dict() + params["stop"] = stop + params = {**params, **kwargs} + options = {"stop": stop} + callback_manager = CallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + config.get("metadata"), + self.metadata, + ) + (run_manager,) = callback_manager.on_llm_start( + dumpd(self), + [prompt], + invocation_params=params, + options=options, + name=config.get("run_name"), + ) + try: + generation: Optional[GenerationChunk] = None + for chunk in self._stream( + prompt, stop=stop, run_manager=run_manager, **kwargs + ): + yield chunk.text + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + except BaseException as e: + run_manager.on_llm_error(e) + raise e + else: + run_manager.on_llm_end(LLMResult(generations=[[generation]])) + + async def astream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> AsyncIterator[str]: + if type(self)._astream == BaseLLM._astream: + # model doesn't implement streaming, so use default implementation + yield await self.ainvoke(input, config=config, stop=stop, **kwargs) + else: + prompt = self._convert_input(input).to_string() + config = config or {} + params = self.dict() + params["stop"] = stop + params = {**params, **kwargs} + options = {"stop": stop} + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + self.callbacks, + self.verbose, + config.get("tags"), + self.tags, + config.get("metadata"), + self.metadata, + ) + (run_manager,) = await callback_manager.on_llm_start( + dumpd(self), + [prompt], + invocation_params=params, + options=options, + name=config.get("run_name"), + ) + try: + generation: Optional[GenerationChunk] = None + async for chunk in self._astream( + prompt, stop=stop, run_manager=run_manager, **kwargs + ): + yield chunk.text + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + except BaseException as e: + await run_manager.on_llm_error(e) + raise e + else: + await run_manager.on_llm_end(LLMResult(generations=[[generation]])) + + # --- Custom methods --- + + @abstractmethod + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Run the LLM on the given prompts.""" + + async def _agenerate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: 
Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Run the LLM on the given prompts.""" + return await asyncio.get_running_loop().run_in_executor( + None, partial(self._generate, **kwargs), prompts, stop, run_manager + ) + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + raise NotImplementedError() + + def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + raise NotImplementedError() + + def generate_prompt( + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, + **kwargs: Any, + ) -> LLMResult: + prompt_strings = [p.to_string() for p in prompts] + return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs) + + async def agenerate_prompt( + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, + **kwargs: Any, + ) -> LLMResult: + prompt_strings = [p.to_string() for p in prompts] + return await self.agenerate( + prompt_strings, stop=stop, callbacks=callbacks, **kwargs + ) + + def _generate_helper( + self, + prompts: List[str], + stop: Optional[List[str]], + run_managers: List[CallbackManagerForLLMRun], + new_arg_supported: bool, + **kwargs: Any, + ) -> LLMResult: + try: + output = ( + self._generate( + prompts, + stop=stop, + # TODO: support multiple run managers + run_manager=run_managers[0] if run_managers else None, + **kwargs, + ) + if new_arg_supported + else self._generate(prompts, stop=stop) + ) + except BaseException as e: + for run_manager in run_managers: + run_manager.on_llm_error(e) + raise e + flattened_outputs = output.flatten() + for manager, flattened_output in zip(run_managers, flattened_outputs): + manager.on_llm_end(flattened_output) + if run_managers: + output.run = [ + RunInfo(run_id=run_manager.run_id) for run_manager in run_managers + ] + return output + + def generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, + *, + tags: Optional[Union[List[str], List[List[str]]]] = None, + metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + run_name: Optional[Union[str, List[str]]] = None, + **kwargs: Any, + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + if not isinstance(prompts, list): + raise ValueError( + "Argument 'prompts' is expected to be of type List[str], received" + f" argument of type {type(prompts)}." 
+ ) + # Create callback managers + if ( + isinstance(callbacks, list) + and callbacks + and ( + isinstance(callbacks[0], (list, BaseCallbackManager)) + or callbacks[0] is None + ) + ): + # We've received a list of callbacks args to apply to each input + assert len(callbacks) == len(prompts) + assert tags is None or ( + isinstance(tags, list) and len(tags) == len(prompts) + ) + assert metadata is None or ( + isinstance(metadata, list) and len(metadata) == len(prompts) + ) + assert run_name is None or ( + isinstance(run_name, list) and len(run_name) == len(prompts) + ) + callbacks = cast(List[Callbacks], callbacks) + tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) + metadata_list = cast( + List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) + ) + run_name_list = run_name or cast( + List[Optional[str]], ([None] * len(prompts)) + ) + callback_managers = [ + CallbackManager.configure( + callback, + self.callbacks, + self.verbose, + tag, + self.tags, + meta, + self.metadata, + ) + for callback, tag, meta in zip(callbacks, tags_list, metadata_list) + ] + else: + # We've received a single callbacks arg to apply to all inputs + callback_managers = [ + CallbackManager.configure( + cast(Callbacks, callbacks), + self.callbacks, + self.verbose, + cast(List[str], tags), + self.tags, + cast(Dict[str, Any], metadata), + self.metadata, + ) + ] * len(prompts) + run_name_list = [cast(Optional[str], run_name)] * len(prompts) + + params = self.dict() + params["stop"] = stop + options = {"stop": stop} + ( + existing_prompts, + llm_string, + missing_prompt_idxs, + missing_prompts, + ) = get_prompts(params, prompts) + disregard_cache = self.cache is not None and not self.cache + new_arg_supported = inspect.signature(self._generate).parameters.get( + "run_manager" + ) + if get_llm_cache() is None or disregard_cache: + if self.cache is not None and self.cache: + raise ValueError( + "Asked to cache, but no cache found at `langchain.cache`." 
+ ) + run_managers = [ + callback_manager.on_llm_start( + dumpd(self), + [prompt], + invocation_params=params, + options=options, + name=run_name, + )[0] + for callback_manager, prompt, run_name in zip( + callback_managers, prompts, run_name_list + ) + ] + output = self._generate_helper( + prompts, stop, run_managers, bool(new_arg_supported), **kwargs + ) + return output + if len(missing_prompts) > 0: + run_managers = [ + callback_managers[idx].on_llm_start( + dumpd(self), + [prompts[idx]], + invocation_params=params, + options=options, + name=run_name_list[idx], + )[0] + for idx in missing_prompt_idxs + ] + new_results = self._generate_helper( + missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs + ) + llm_output = update_cache( + existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts + ) + run_info = ( + [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers] + if run_managers + else None + ) + else: + llm_output = {} + run_info = None + generations = [existing_prompts[i] for i in range(len(prompts))] + return LLMResult(generations=generations, llm_output=llm_output, run=run_info) + + async def _agenerate_helper( + self, + prompts: List[str], + stop: Optional[List[str]], + run_managers: List[AsyncCallbackManagerForLLMRun], + new_arg_supported: bool, + **kwargs: Any, + ) -> LLMResult: + try: + output = ( + await self._agenerate( + prompts, + stop=stop, + run_manager=run_managers[0] if run_managers else None, + **kwargs, + ) + if new_arg_supported + else await self._agenerate(prompts, stop=stop) + ) + except BaseException as e: + await asyncio.gather( + *[run_manager.on_llm_error(e) for run_manager in run_managers] + ) + raise e + flattened_outputs = output.flatten() + await asyncio.gather( + *[ + run_manager.on_llm_end(flattened_output) + for run_manager, flattened_output in zip( + run_managers, flattened_outputs + ) + ] + ) + if run_managers: + output.run = [ + RunInfo(run_id=run_manager.run_id) for run_manager in run_managers + ] + return output + + async def agenerate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, + *, + tags: Optional[Union[List[str], List[List[str]]]] = None, + metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + run_name: Optional[Union[str, List[str]]] = None, + **kwargs: Any, + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + # Create callback managers + if isinstance(callbacks, list) and ( + isinstance(callbacks[0], (list, BaseCallbackManager)) + or callbacks[0] is None + ): + # We've received a list of callbacks args to apply to each input + assert len(callbacks) == len(prompts) + assert tags is None or ( + isinstance(tags, list) and len(tags) == len(prompts) + ) + assert metadata is None or ( + isinstance(metadata, list) and len(metadata) == len(prompts) + ) + assert run_name is None or ( + isinstance(run_name, list) and len(run_name) == len(prompts) + ) + callbacks = cast(List[Callbacks], callbacks) + tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) + metadata_list = cast( + List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) + ) + run_name_list = run_name or cast( + List[Optional[str]], ([None] * len(prompts)) + ) + callback_managers = [ + AsyncCallbackManager.configure( + callback, + self.callbacks, + self.verbose, + tag, + self.tags, + meta, + self.metadata, + ) + for callback, tag, meta in zip(callbacks, tags_list, metadata_list) + ] + else: + 
# We've received a single callbacks arg to apply to all inputs + callback_managers = [ + AsyncCallbackManager.configure( + cast(Callbacks, callbacks), + self.callbacks, + self.verbose, + cast(List[str], tags), + self.tags, + cast(Dict[str, Any], metadata), + self.metadata, + ) + ] * len(prompts) + run_name_list = [cast(Optional[str], run_name)] * len(prompts) + + params = self.dict() + params["stop"] = stop + options = {"stop": stop} + ( + existing_prompts, + llm_string, + missing_prompt_idxs, + missing_prompts, + ) = get_prompts(params, prompts) + disregard_cache = self.cache is not None and not self.cache + new_arg_supported = inspect.signature(self._agenerate).parameters.get( + "run_manager" + ) + if get_llm_cache() is None or disregard_cache: + if self.cache is not None and self.cache: + raise ValueError( + "Asked to cache, but no cache found at `langchain.cache`." + ) + run_managers = await asyncio.gather( + *[ + callback_manager.on_llm_start( + dumpd(self), + [prompt], + invocation_params=params, + options=options, + name=run_name, + ) + for callback_manager, prompt, run_name in zip( + callback_managers, prompts, run_name_list + ) + ] + ) + run_managers = [r[0] for r in run_managers] + output = await self._agenerate_helper( + prompts, stop, run_managers, bool(new_arg_supported), **kwargs + ) + return output + if len(missing_prompts) > 0: + run_managers = await asyncio.gather( + *[ + callback_managers[idx].on_llm_start( + dumpd(self), + [prompts[idx]], + invocation_params=params, + options=options, + name=run_name_list[idx], + ) + for idx in missing_prompt_idxs + ] + ) + run_managers = [r[0] for r in run_managers] + new_results = await self._agenerate_helper( + missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs + ) + llm_output = update_cache( + existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts + ) + run_info = ( + [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers] + if run_managers + else None + ) + else: + llm_output = {} + run_info = None + generations = [existing_prompts[i] for i in range(len(prompts))] + return LLMResult(generations=generations, llm_output=llm_output, run=run_info) + + def __call__( + self, + prompt: str, + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> str: + """Check Cache and run the LLM on the given prompt and input.""" + if not isinstance(prompt, str): + raise ValueError( + "Argument `prompt` is expected to be a string. Instead found " + f"{type(prompt)}. If you want to run the LLM on multiple prompts, use " + "`generate` instead." 
+ ) + return ( + self.generate( + [prompt], + stop=stop, + callbacks=callbacks, + tags=tags, + metadata=metadata, + **kwargs, + ) + .generations[0][0] + .text + ) + + async def _call_async( + self, + prompt: str, + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> str: + """Check Cache and run the LLM on the given prompt and input.""" + result = await self.agenerate( + [prompt], + stop=stop, + callbacks=callbacks, + tags=tags, + metadata=metadata, + **kwargs, + ) + return result.generations[0][0].text + + def predict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: + if stop is None: + _stop = None + else: + _stop = list(stop) + return self(text, stop=_stop, **kwargs) + + def predict_messages( + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> BaseMessage: + text = get_buffer_string(messages) + if stop is None: + _stop = None + else: + _stop = list(stop) + content = self(text, stop=_stop, **kwargs) + return AIMessage(content=content) + + async def apredict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: + if stop is None: + _stop = None + else: + _stop = list(stop) + return await self._call_async(text, stop=_stop, **kwargs) + + async def apredict_messages( + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> BaseMessage: + text = get_buffer_string(messages) + if stop is None: + _stop = None + else: + _stop = list(stop) + content = await self._call_async(text, stop=_stop, **kwargs) + return AIMessage(content=content) + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {} + + def __str__(self) -> str: + """Get a string representation of the object for printing.""" + cls_name = f"\033[1m{self.__class__.__name__}\033[0m" + return f"{cls_name}\nParams: {self._identifying_params}" + + @property + @abstractmethod + def _llm_type(self) -> str: + """Return type of llm.""" + + def dict(self, **kwargs: Any) -> Dict: + """Return a dictionary of the LLM.""" + starter_dict = dict(self._identifying_params) + starter_dict["_type"] = self._llm_type + return starter_dict + + def save(self, file_path: Union[Path, str]) -> None: + """Save the LLM. + + Args: + file_path: Path to file to save the LLM to. + + Example: + .. code-block:: python + + llm.save(file_path="path/llm.yaml") + """ + # Convert file to Path object. + if isinstance(file_path, str): + save_path = Path(file_path) + else: + save_path = file_path + + directory_path = save_path.parent + directory_path.mkdir(parents=True, exist_ok=True) + + # Fetch dictionary to save + prompt_dict = self.dict() + + if save_path.suffix == ".json": + with open(file_path, "w") as f: + json.dump(prompt_dict, f, indent=4) + elif save_path.suffix == ".yaml": + with open(file_path, "w") as f: + yaml.dump(prompt_dict, f, default_flow_style=False) + else: + raise ValueError(f"{save_path} must be json or yaml") + + +class LLM(BaseLLM): + """Base LLM abstract class. + + The purpose of this class is to expose a simpler interface for working + with LLMs, rather than expect the user to implement the full _generate method. 
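As the `LLM` docstring above describes, subclasses implement the simpler `_call` hook (declared just below) rather than the full `_generate`; a minimal, hypothetical example:

.. code-block:: python

    from typing import Any, List, Optional

    from langchain_core.callbacks.manager import CallbackManagerForLLMRun
    from langchain_core.llm import LLM


    class ReverseLLM(LLM):
        """Toy LLM that returns the prompt reversed."""

        @property
        def _llm_type(self) -> str:
            return "reverse"

        def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> str:
            return prompt[::-1]


    print(ReverseLLM().invoke("hello"))  # "olleh"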
+ """ + + @abstractmethod + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Run the LLM on the given prompt and input.""" + + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Run the LLM on the given prompt and input.""" + return await asyncio.get_running_loop().run_in_executor( + None, partial(self._call, **kwargs), prompt, stop, run_manager + ) + + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + # TODO: add caching here. + generations = [] + new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") + for prompt in prompts: + text = ( + self._call(prompt, stop=stop, run_manager=run_manager, **kwargs) + if new_arg_supported + else self._call(prompt, stop=stop, **kwargs) + ) + generations.append([Generation(text=text)]) + return LLMResult(generations=generations) + + async def _agenerate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + generations = [] + new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") + for prompt in prompts: + text = ( + await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs) + if new_arg_supported + else await self._acall(prompt, stop=stop, **kwargs) + ) + generations.append([Generation(text=text)]) + return LLMResult(generations=generations) diff --git a/libs/core/langchain_core/load/__init__.py b/libs/core/langchain_core/load/__init__.py new file mode 100644 index 00000000000..5232da55bb4 --- /dev/null +++ b/libs/core/langchain_core/load/__init__.py @@ -0,0 +1,6 @@ +"""Serialization and deserialization.""" +from langchain_core.load.dump import dumpd, dumps +from langchain_core.load.load import load, loads +from langchain_core.load.serializable import Serializable + +__all__ = ["dumpd", "dumps", "load", "loads", "Serializable"] diff --git a/libs/core/langchain_core/load/dump.py b/libs/core/langchain_core/load/dump.py new file mode 100644 index 00000000000..2e2293f99e7 --- /dev/null +++ b/libs/core/langchain_core/load/dump.py @@ -0,0 +1,26 @@ +import json +from typing import Any, Dict + +from langchain_core.load.serializable import Serializable, to_json_not_implemented + + +def default(obj: Any) -> Any: + """Return a default value for a Serializable object or + a SerializedNotImplemented object.""" + if isinstance(obj, Serializable): + return obj.to_json() + else: + return to_json_not_implemented(obj) + + +def dumps(obj: Any, *, pretty: bool = False) -> str: + """Return a json string representation of an object.""" + if pretty: + return json.dumps(obj, default=default, indent=2) + else: + return json.dumps(obj, default=default) + + +def dumpd(obj: Any) -> Dict[str, Any]: + """Return a json dict representation of an object.""" + return json.loads(dumps(obj)) diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py new file mode 100644 index 00000000000..2a0161ded22 --- /dev/null +++ b/libs/core/langchain_core/load/load.py @@ -0,0 +1,130 @@ +import importlib +import json +import os +from typing 
import Any, Dict, List, Optional + +from langchain_core.load.serializable import Serializable + +DEFAULT_NAMESPACES = ["langchain", "langchain_core"] + + +class Reviver: + """Reviver for JSON objects.""" + + def __init__( + self, + secrets_map: Optional[Dict[str, str]] = None, + valid_namespaces: Optional[List[str]] = None, + ) -> None: + self.secrets_map = secrets_map or dict() + # By default only support langchain, but user can pass in additional namespaces + self.valid_namespaces = ( + [*DEFAULT_NAMESPACES, *valid_namespaces] + if valid_namespaces + else DEFAULT_NAMESPACES + ) + + def __call__(self, value: Dict[str, Any]) -> Any: + if ( + value.get("lc", None) == 1 + and value.get("type", None) == "secret" + and value.get("id", None) is not None + ): + [key] = value["id"] + if key in self.secrets_map: + return self.secrets_map[key] + else: + if key in os.environ and os.environ[key]: + return os.environ[key] + raise KeyError(f'Missing key "{key}" in load(secrets_map)') + + if ( + value.get("lc", None) == 1 + and value.get("type", None) == "not_implemented" + and value.get("id", None) is not None + ): + raise NotImplementedError( + "Trying to load an object that doesn't implement " + f"serialization: {value}" + ) + + if ( + value.get("lc", None) == 1 + and value.get("type", None) == "constructor" + and value.get("id", None) is not None + ): + [*namespace, name] = value["id"] + + if namespace[0] not in self.valid_namespaces: + raise ValueError(f"Invalid namespace: {value}") + + # The root namespace "langchain" is not a valid identifier. + if len(namespace) == 1 and namespace[0] == "langchain": + raise ValueError(f"Invalid namespace: {value}") + + mod = importlib.import_module(".".join(namespace)) + cls = getattr(mod, name) + + # The class must be a subclass of Serializable. + if not issubclass(cls, Serializable): + raise ValueError(f"Invalid namespace: {value}") + + # We don't need to recurse on kwargs + # as json.loads will do that for us. + kwargs = value.get("kwargs", dict()) + return cls(**kwargs) + + return value + + +def loads( + text: str, + *, + secrets_map: Optional[Dict[str, str]] = None, + valid_namespaces: Optional[List[str]] = None, +) -> Any: + """Revive a LangChain class from a JSON string. + Equivalent to `load(json.loads(text))`. + + Args: + text: The string to load. + secrets_map: A map of secrets to load. + valid_namespaces: A list of additional namespaces (modules) + to allow to be deserialized. + + Returns: + Revived LangChain objects. + """ + return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces)) + + +def load( + obj: Any, + *, + secrets_map: Optional[Dict[str, str]] = None, + valid_namespaces: Optional[List[str]] = None, +) -> Any: + """Revive a LangChain class from a JSON object. Use this if you already + have a parsed JSON object, eg. from `json.load` or `orjson.loads`. + + Args: + obj: The object to load. + secrets_map: A map of secrets to load. + valid_namespaces: A list of additional namespaces (modules) + to allow to be deserialized. + + Returns: + Revived LangChain objects. 
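Taken together, `dumps`/`dumpd` and `load`/`loads` are meant to round-trip serializable objects; a sketch assuming `PromptTemplate` (added elsewhere in this patch) opts into serialization via `is_lc_serializable`:

.. code-block:: python

    from langchain_core.load import dumps, loads
    from langchain_core.prompts import PromptTemplate

    prompt = PromptTemplate.from_template("Tell me a joke about {topic}")

    serialized = dumps(prompt, pretty=True)  # JSON string with lc/type/id/kwargs fields
    revived = loads(serialized)              # Reviver re-imports and reconstructs the class

    assert revived.format(topic="parrots") == prompt.format(topic="parrots")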
+ """ + reviver = Reviver(secrets_map, valid_namespaces) + + def _load(obj: Any) -> Any: + if isinstance(obj, dict): + # Need to revive leaf nodes before reviving this node + loaded_obj = {k: _load(v) for k, v in obj.items()} + return reviver(loaded_obj) + if isinstance(obj, list): + return [_load(o) for o in obj] + return obj + + return _load(obj) diff --git a/libs/core/langchain_core/load/serializable.py b/libs/core/langchain_core/load/serializable.py new file mode 100644 index 00000000000..e7733d5a0d0 --- /dev/null +++ b/libs/core/langchain_core/load/serializable.py @@ -0,0 +1,207 @@ +from abc import ABC +from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast + +from langchain_core.pydantic_v1 import BaseModel, PrivateAttr + + +class BaseSerialized(TypedDict): + """Base class for serialized objects.""" + + lc: int + id: List[str] + + +class SerializedConstructor(BaseSerialized): + """Serialized constructor.""" + + type: Literal["constructor"] + kwargs: Dict[str, Any] + + +class SerializedSecret(BaseSerialized): + """Serialized secret.""" + + type: Literal["secret"] + + +class SerializedNotImplemented(BaseSerialized): + """Serialized not implemented.""" + + type: Literal["not_implemented"] + repr: Optional[str] + + +def try_neq_default(value: Any, key: str, model: BaseModel) -> bool: + try: + return model.__fields__[key].get_default() != value + except Exception: + return True + + +class Serializable(BaseModel, ABC): + """Serializable base class.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Is this class serializable?""" + return False + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object. + + For example, if the class is `langchain.llms.openai.OpenAI`, then the + namespace is ["langchain", "llms", "openai"] + """ + return cls.__module__.split(".") + + @property + def lc_secrets(self) -> Dict[str, str]: + """A map of constructor argument names to secret ids. + + For example, + {"openai_api_key": "OPENAI_API_KEY"} + """ + return dict() + + @property + def lc_attributes(self) -> Dict: + """List of attribute names that should be included in the serialized kwargs. + + These attributes must be accepted by the constructor. + """ + return {} + + @classmethod + def lc_id(cls) -> List[str]: + """A unique identifier for this class for serialization purposes. + + The unique identifier is a list of strings that describes the path + to the object. 
+ """ + return [*cls.get_lc_namespace(), cls.__name__] + + class Config: + extra = "ignore" + + def __repr_args__(self) -> Any: + return [ + (k, v) + for k, v in super().__repr_args__() + if (k not in self.__fields__ or try_neq_default(v, k, self)) + ] + + _lc_kwargs = PrivateAttr(default_factory=dict) + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._lc_kwargs = kwargs + + def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]: + if not self.is_lc_serializable(): + return self.to_json_not_implemented() + + secrets = dict() + # Get latest values for kwargs if there is an attribute with same name + lc_kwargs = { + k: getattr(self, k, v) + for k, v in self._lc_kwargs.items() + if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore + } + + # Merge the lc_secrets and lc_attributes from every class in the MRO + for cls in [None, *self.__class__.mro()]: + # Once we get to Serializable, we're done + if cls is Serializable: + break + + if cls: + deprecated_attributes = [ + "lc_namespace", + "lc_serializable", + ] + + for attr in deprecated_attributes: + if hasattr(cls, attr): + raise ValueError( + f"Class {self.__class__} has a deprecated " + f"attribute {attr}. Please use the corresponding " + f"classmethod instead." + ) + + # Get a reference to self bound to each class in the MRO + this = cast(Serializable, self if cls is None else super(cls, self)) + + secrets.update(this.lc_secrets) + lc_kwargs.update(this.lc_attributes) + + # include all secrets, even if not specified in kwargs + # as these secrets may be passed as an environment variable instead + for key in secrets.keys(): + secret_value = getattr(self, key, None) or lc_kwargs.get(key) + if secret_value is not None: + lc_kwargs.update({key: secret_value}) + + return { + "lc": 1, + "type": "constructor", + "id": self.lc_id(), + "kwargs": lc_kwargs + if not secrets + else _replace_secrets(lc_kwargs, secrets), + } + + def to_json_not_implemented(self) -> SerializedNotImplemented: + return to_json_not_implemented(self) + + +def _replace_secrets( + root: Dict[Any, Any], secrets_map: Dict[str, str] +) -> Dict[Any, Any]: + result = root.copy() + for path, secret_id in secrets_map.items(): + [*parts, last] = path.split(".") + current = result + for part in parts: + if part not in current: + break + current[part] = current[part].copy() + current = current[part] + if last in current: + current[last] = { + "lc": 1, + "type": "secret", + "id": [secret_id], + } + return result + + +def to_json_not_implemented(obj: object) -> SerializedNotImplemented: + """Serialize a "not implemented" object. 
+ + Args: + obj: object to serialize + + Returns: + SerializedNotImplemented + """ + _id: List[str] = [] + try: + if hasattr(obj, "__name__"): + _id = [*obj.__module__.split("."), obj.__name__] + elif hasattr(obj, "__class__"): + _id = [*obj.__class__.__module__.split("."), obj.__class__.__name__] + except Exception: + pass + + result: SerializedNotImplemented = { + "lc": 1, + "type": "not_implemented", + "id": _id, + "repr": None, + } + try: + result["repr"] = repr(obj) + except Exception: + pass + return result diff --git a/libs/core/langchain_core/output_parsers/__init__.py b/libs/core/langchain_core/output_parsers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py new file mode 100644 index 00000000000..079d204a4e0 --- /dev/null +++ b/libs/core/langchain_core/output_parsers/list.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import re +from abc import abstractmethod +from typing import List + +from langchain_core.schema import BaseOutputParser + + +class ListOutputParser(BaseOutputParser[List[str]]): + """Parse the output of an LLM call to a list.""" + + @property + def _type(self) -> str: + return "list" + + @abstractmethod + def parse(self, text: str) -> List[str]: + """Parse the output of an LLM call.""" + + +class CommaSeparatedListOutputParser(ListOutputParser): + """Parse the output of an LLM call to a comma-separated list.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + def get_format_instructions(self) -> str: + return ( + "Your response should be a list of comma separated values, " + "eg: `foo, bar, baz`" + ) + + def parse(self, text: str) -> List[str]: + """Parse the output of an LLM call.""" + return text.strip().split(", ") + + @property + def _type(self) -> str: + return "comma-separated-list" + + +class NumberedListOutputParser(ListOutputParser): + """Parse a numbered list.""" + + def get_format_instructions(self) -> str: + return ( + "Your response should be a numbered list with each item on a new line. " + "For example: \n\n1. foo\n\n2. bar\n\n3. baz" + ) + + def parse(self, text: str) -> List[str]: + """Parse the output of an LLM call.""" + pattern = r"\d+\.\s([^\n]+)" + + # Extract the text of each item + matches = re.findall(pattern, text) + return matches + + @property + def _type(self) -> str: + return "numbered-list" + + +class MarkdownListOutputParser(ListOutputParser): + """Parse a markdown list.""" + + def get_format_instructions(self) -> str: + return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`" + + def parse(self, text: str) -> List[str]: + """Parse the output of an LLM call.""" + pattern = r"-\s([^\n]+)" + return re.findall(pattern, text) + + @property + def _type(self) -> str: + return "markdown-list" diff --git a/libs/core/langchain_core/prompts/__init__.py b/libs/core/langchain_core/prompts/__init__.py new file mode 100644 index 00000000000..606255c067d --- /dev/null +++ b/libs/core/langchain_core/prompts/__init__.py @@ -0,0 +1,75 @@ +"""**Prompt** is the input to the model. + +Prompt is often constructed +from multiple components. Prompt classes and functions make constructing + and working with prompts easy. + +**Class hierarchy:** + +.. 
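[Editor's note] Before moving on to the prompts package: a quick sketch of how the three list parsers defined in `list.py` above behave on typical model output. Nothing here goes beyond what the module defines; the inputs are illustrative:

.. code-block:: python

    from langchain_core.output_parsers.list import (
        CommaSeparatedListOutputParser,
        MarkdownListOutputParser,
        NumberedListOutputParser,
    )

    CommaSeparatedListOutputParser().parse("foo, bar, baz")
    # ['foo', 'bar', 'baz']

    NumberedListOutputParser().parse("1. foo\n2. bar\n3. baz")
    # ['foo', 'bar', 'baz']

    MarkdownListOutputParser().parse("- foo\n- bar\n- baz")
    # ['foo', 'bar', 'baz']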
code-block:: + + BasePromptTemplate --> PipelinePromptTemplate + StringPromptTemplate --> PromptTemplate + FewShotPromptTemplate + FewShotPromptWithTemplates + BaseChatPromptTemplate --> AutoGPTPrompt + ChatPromptTemplate --> AgentScratchPadChatPromptTemplate + + + + BaseMessagePromptTemplate --> MessagesPlaceholder + BaseStringMessagePromptTemplate --> ChatMessagePromptTemplate + HumanMessagePromptTemplate + AIMessagePromptTemplate + SystemMessagePromptTemplate + + PromptValue --> StringPromptValue + ChatPromptValue + +""" # noqa: E501 +from langchain_core.prompts.base import StringPromptTemplate +from langchain_core.prompts.chat import ( + AIMessagePromptTemplate, + BaseChatPromptTemplate, + ChatMessagePromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, + SystemMessagePromptTemplate, +) +from langchain_core.prompts.example_selector import ( + LengthBasedExampleSelector, + MaxMarginalRelevanceExampleSelector, + SemanticSimilarityExampleSelector, +) +from langchain_core.prompts.few_shot import ( + FewShotChatMessagePromptTemplate, + FewShotPromptTemplate, +) +from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates +from langchain_core.prompts.loading import load_prompt +from langchain_core.prompts.pipeline import PipelinePromptTemplate +from langchain_core.prompts.prompt import Prompt, PromptTemplate +from langchain_core.schema.prompt_template import BasePromptTemplate + +__all__ = [ + "AIMessagePromptTemplate", + "BaseChatPromptTemplate", + "BasePromptTemplate", + "ChatMessagePromptTemplate", + "ChatPromptTemplate", + "FewShotPromptTemplate", + "FewShotPromptWithTemplates", + "HumanMessagePromptTemplate", + "LengthBasedExampleSelector", + "MaxMarginalRelevanceExampleSelector", + "MessagesPlaceholder", + "PipelinePromptTemplate", + "Prompt", + "PromptTemplate", + "SemanticSimilarityExampleSelector", + "StringPromptTemplate", + "SystemMessagePromptTemplate", + "load_prompt", + "FewShotChatMessagePromptTemplate", +] diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py new file mode 100644 index 00000000000..6a1cda6fedf --- /dev/null +++ b/libs/core/langchain_core/prompts/base.py @@ -0,0 +1,173 @@ +"""BasePrompt schema definition.""" +from __future__ import annotations + +import warnings +from abc import ABC +from string import Formatter +from typing import Any, Callable, Dict, List, Literal, Set + +from langchain_core.schema.messages import BaseMessage, HumanMessage +from langchain_core.schema.prompt import PromptValue +from langchain_core.schema.prompt_template import BasePromptTemplate +from langchain_core.utils.formatting import formatter + + +def jinja2_formatter(template: str, **kwargs: Any) -> str: + """Format a template using jinja2. + + *Security warning*: As of LangChain 0.0.329, this method uses Jinja2's + SandboxedEnvironment by default. However, this sand-boxing should + be treated as a best-effort approach rather than a guarantee of security. + Do not accept jinja2 templates from untrusted sources as they may lead + to arbitrary Python code execution. + + https://jinja.palletsprojects.com/en/3.1.x/sandbox/ + """ + try: + from jinja2.sandbox import SandboxedEnvironment + except ImportError: + raise ImportError( + "jinja2 not installed, which is needed to use the jinja2_formatter. " + "Please install it with `pip install jinja2`." + "Please be cautious when using jinja2 templates. 
" + "Do not expand jinja2 templates using unverified or user-controlled " + "inputs as that can result in arbitrary Python code execution." + ) + + # This uses a sandboxed environment to prevent arbitrary code execution. + # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing. + # Please treat this sand-boxing as a best-effort approach rather than + # a guarantee of security. + # We recommend to never use jinja2 templates with untrusted inputs. + # https://jinja.palletsprojects.com/en/3.1.x/sandbox/ + # approach not a guarantee of security. + return SandboxedEnvironment().from_string(template).render(**kwargs) + + +def validate_jinja2(template: str, input_variables: List[str]) -> None: + """ + Validate that the input variables are valid for the template. + Issues a warning if missing or extra variables are found. + + Args: + template: The template string. + input_variables: The input variables. + """ + input_variables_set = set(input_variables) + valid_variables = _get_jinja2_variables_from_template(template) + missing_variables = valid_variables - input_variables_set + extra_variables = input_variables_set - valid_variables + + warning_message = "" + if missing_variables: + warning_message += f"Missing variables: {missing_variables} " + + if extra_variables: + warning_message += f"Extra variables: {extra_variables}" + + if warning_message: + warnings.warn(warning_message.strip()) + + +def _get_jinja2_variables_from_template(template: str) -> Set[str]: + try: + from jinja2 import Environment, meta + except ImportError: + raise ImportError( + "jinja2 not installed, which is needed to use the jinja2_formatter. " + "Please install it with `pip install jinja2`." + ) + env = Environment() + ast = env.parse(template) + variables = meta.find_undeclared_variables(ast) + return variables + + +DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { + "f-string": formatter.format, + "jinja2": jinja2_formatter, +} + +DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = { + "f-string": formatter.validate_input_variables, + "jinja2": validate_jinja2, +} + + +def check_valid_template( + template: str, template_format: str, input_variables: List[str] +) -> None: + """Check that template string is valid. + + Args: + template: The template string. + template_format: The template format. Should be one of "f-string" or "jinja2". + input_variables: The input variables. + + Raises: + ValueError: If the template format is not supported. + """ + if template_format not in DEFAULT_FORMATTER_MAPPING: + valid_formats = list(DEFAULT_FORMATTER_MAPPING) + raise ValueError( + f"Invalid template format. Got `{template_format}`;" + f" should be one of {valid_formats}" + ) + try: + validator_func = DEFAULT_VALIDATOR_MAPPING[template_format] + validator_func(template, input_variables) + except KeyError as e: + raise ValueError( + "Invalid prompt schema; check for mismatched or missing input parameters. " + + str(e) + ) + + +def get_template_variables(template: str, template_format: str) -> List[str]: + """Get the variables from the template. + + Args: + template: The template string. + template_format: The template format. Should be one of "f-string" or "jinja2". + + Returns: + The variables from the template. + + Raises: + ValueError: If the template format is not supported. 
+ """ + if template_format == "jinja2": + # Get the variables for the template + input_variables = _get_jinja2_variables_from_template(template) + elif template_format == "f-string": + input_variables = { + v for _, v, _, _ in Formatter().parse(template) if v is not None + } + else: + raise ValueError(f"Unsupported template format: {template_format}") + + return sorted(input_variables) + + +class StringPromptValue(PromptValue): + """String prompt value.""" + + text: str + """Prompt text.""" + type: Literal["StringPromptValue"] = "StringPromptValue" + + def to_string(self) -> str: + """Return prompt as string.""" + return self.text + + def to_messages(self) -> List[BaseMessage]: + """Return prompt as messages.""" + return [HumanMessage(content=self.text)] + + +class StringPromptTemplate(BasePromptTemplate, ABC): + """String prompt that exposes the format method, returning a prompt.""" + + def format_prompt(self, **kwargs: Any) -> PromptValue: + """Create Chat Messages.""" + return StringPromptValue(text=self.format(**kwargs)) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py new file mode 100644 index 00000000000..01e2ebde98b --- /dev/null +++ b/libs/core/langchain_core/prompts/chat.py @@ -0,0 +1,748 @@ +"""Chat prompt template.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +from langchain_core._api import deprecated +from langchain_core.load.serializable import Serializable +from langchain_core.prompts.base import StringPromptTemplate +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import ( + BasePromptTemplate, + PromptValue, +) +from langchain_core.schema.messages import ( + AIMessage, + AnyMessage, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, + get_buffer_string, +) + + +class BaseMessagePromptTemplate(Serializable, ABC): + """Base class for message prompt templates.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether or not the class is serializable.""" + return True + + @abstractmethod + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format messages from kwargs. Should return a list of BaseMessages. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + List of BaseMessages. + """ + + @property + @abstractmethod + def input_variables(self) -> List[str]: + """Input variables for this prompt template. + + Returns: + List of input variables. + """ + + def __add__(self, other: Any) -> ChatPromptTemplate: + """Combine two prompt templates. + + Args: + other: Another prompt template. + + Returns: + Combined prompt template. + """ + prompt = ChatPromptTemplate(messages=[self]) + return prompt + other + + +class MessagesPlaceholder(BaseMessagePromptTemplate): + """Prompt template that assumes variable is already list of messages.""" + + variable_name: str + """Name of variable to use as messages.""" + + def __init__(self, variable_name: str, **kwargs: Any): + return super().__init__(variable_name=variable_name, **kwargs) + + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format messages from kwargs. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + List of BaseMessage. 
+ """ + value = kwargs[self.variable_name] + if not isinstance(value, list): + raise ValueError( + f"variable {self.variable_name} should be a list of base messages, " + f"got {value}" + ) + for v in value: + if not isinstance(v, BaseMessage): + raise ValueError( + f"variable {self.variable_name} should be a list of base messages," + f" got {value}" + ) + return value + + @property + def input_variables(self) -> List[str]: + """Input variables for this prompt template. + + Returns: + List of input variable names. + """ + return [self.variable_name] + + +MessagePromptTemplateT = TypeVar( + "MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate" +) +"""Type variable for message prompt templates.""" + + +class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): + """Base class for message prompt templates that use a string prompt template.""" + + prompt: StringPromptTemplate + """String prompt template.""" + additional_kwargs: dict = Field(default_factory=dict) + """Additional keyword arguments to pass to the prompt template.""" + + @classmethod + def from_template( + cls: Type[MessagePromptTemplateT], + template: str, + template_format: str = "f-string", + **kwargs: Any, + ) -> MessagePromptTemplateT: + """Create a class from a string template. + + Args: + template: a template. + template_format: format of the template. + **kwargs: keyword arguments to pass to the constructor. + + Returns: + A new instance of this class. + """ + prompt = PromptTemplate.from_template(template, template_format=template_format) + return cls(prompt=prompt, **kwargs) + + @classmethod + def from_template_file( + cls: Type[MessagePromptTemplateT], + template_file: Union[str, Path], + input_variables: List[str], + **kwargs: Any, + ) -> MessagePromptTemplateT: + """Create a class from a template file. + + Args: + template_file: path to a template file. String or Path. + input_variables: list of input variables. + **kwargs: keyword arguments to pass to the constructor. + + Returns: + A new instance of this class. + """ + prompt = PromptTemplate.from_file(template_file, input_variables) + return cls(prompt=prompt, **kwargs) + + @abstractmethod + def format(self, **kwargs: Any) -> BaseMessage: + """Format the prompt template. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + Formatted message. + """ + + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format messages from kwargs. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + List of BaseMessages. + """ + return [self.format(**kwargs)] + + @property + def input_variables(self) -> List[str]: + """ + Input variables for this prompt template. + + Returns: + List of input variable names. + """ + return self.prompt.input_variables + + +class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): + """Chat message prompt template.""" + + role: str + """Role of the message.""" + + def format(self, **kwargs: Any) -> BaseMessage: + """Format the prompt template. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + Formatted message. + """ + text = self.prompt.format(**kwargs) + return ChatMessage( + content=text, role=self.role, additional_kwargs=self.additional_kwargs + ) + + +class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate): + """Human message prompt template. This is a message sent from the user.""" + + def format(self, **kwargs: Any) -> BaseMessage: + """Format the prompt template. 
+ + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + Formatted message. + """ + text = self.prompt.format(**kwargs) + return HumanMessage(content=text, additional_kwargs=self.additional_kwargs) + + +class AIMessagePromptTemplate(BaseStringMessagePromptTemplate): + """AI message prompt template. This is a message sent from the AI.""" + + def format(self, **kwargs: Any) -> BaseMessage: + """Format the prompt template. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + Formatted message. + """ + text = self.prompt.format(**kwargs) + return AIMessage(content=text, additional_kwargs=self.additional_kwargs) + + +class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate): + """System message prompt template. + This is a message that is not sent to the user. + """ + + def format(self, **kwargs: Any) -> BaseMessage: + """Format the prompt template. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + Formatted message. + """ + text = self.prompt.format(**kwargs) + return SystemMessage(content=text, additional_kwargs=self.additional_kwargs) + + +class ChatPromptValue(PromptValue): + """Chat prompt value. + + A type of a prompt value that is built from messages. + """ + + messages: Sequence[BaseMessage] + """List of messages.""" + + def to_string(self) -> str: + """Return prompt as string.""" + return get_buffer_string(self.messages) + + def to_messages(self) -> List[BaseMessage]: + """Return prompt as a list of messages.""" + return list(self.messages) + + +class ChatPromptValueConcrete(ChatPromptValue): + """Chat prompt value which explicitly lists out the message types it accepts. + For use in external schemas.""" + + messages: Sequence[AnyMessage] + + type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete" + + +class BaseChatPromptTemplate(BasePromptTemplate, ABC): + """Base class for chat prompt templates.""" + + @property + def lc_attributes(self) -> Dict: + """ + Return a list of attribute names that should be included in the + serialized kwargs. These attributes must be accepted by the + constructor. + """ + return {"input_variables": self.input_variables} + + def format(self, **kwargs: Any) -> str: + """Format the chat template into a string. + + Args: + **kwargs: keyword arguments to use for filling in template variables + in all the template messages in this chat template. + + Returns: + formatted string + """ + return self.format_prompt(**kwargs).to_string() + + def format_prompt(self, **kwargs: Any) -> PromptValue: + """ + Format prompt. Should return a PromptValue. + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + PromptValue. + """ + messages = self.format_messages(**kwargs) + return ChatPromptValue(messages=messages) + + @abstractmethod + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format kwargs into a list of messages.""" + + +MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate] + +MessageLikeRepresentation = Union[ + MessageLike, + Tuple[str, str], + Tuple[Type, str], + str, +] + + +class ChatPromptTemplate(BaseChatPromptTemplate): + """A prompt template for chat models. + + Use to create flexible templated prompts for chat models. + + Examples: + + .. code-block:: python + + from langchain_core.prompts import ChatPromptTemplate + + template = ChatPromptTemplate.from_messages([ + ("system", "You are a helpful AI bot. 
Your name is {name}."), + ("human", "Hello, how are you doing?"), + ("ai", "I'm doing well, thanks!"), + ("human", "{user_input}"), + ]) + + messages = template.format_messages( + name="Bob", + user_input="What is your name?" + ) + """ + + input_variables: List[str] + """List of input variables in template messages. Used for validation.""" + messages: List[MessageLike] + """List of messages consisting of either message prompt templates or messages.""" + validate_template: bool = False + """Whether or not to try validating the template.""" + + def __add__(self, other: Any) -> ChatPromptTemplate: + """Combine two prompt templates. + + Args: + other: Another prompt template. + + Returns: + Combined prompt template. + """ + # Allow for easy combining + if isinstance(other, ChatPromptTemplate): + return ChatPromptTemplate(messages=self.messages + other.messages) + elif isinstance( + other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate) + ): + return ChatPromptTemplate(messages=self.messages + [other]) + elif isinstance(other, (list, tuple)): + _other = ChatPromptTemplate.from_messages(other) + return ChatPromptTemplate(messages=self.messages + _other.messages) + elif isinstance(other, str): + prompt = HumanMessagePromptTemplate.from_template(other) + return ChatPromptTemplate(messages=self.messages + [prompt]) + else: + raise NotImplementedError(f"Unsupported operand type for +: {type(other)}") + + @root_validator(pre=True) + def validate_input_variables(cls, values: dict) -> dict: + """Validate input variables. + + If input_variables is not set, it will be set to the union of + all input variables in the messages. + + Args: + values: values to validate. + + Returns: + Validated values. + """ + messages = values["messages"] + input_vars = set() + input_types: Dict[str, Any] = values.get("input_types", {}) + for message in messages: + if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)): + input_vars.update(message.input_variables) + if isinstance(message, MessagesPlaceholder): + if message.variable_name not in input_types: + input_types[message.variable_name] = List[AnyMessage] + if "partial_variables" in values: + input_vars = input_vars - set(values["partial_variables"]) + if "input_variables" in values and values.get("validate_template"): + if input_vars != set(values["input_variables"]): + raise ValueError( + "Got mismatched input_variables. " + f"Expected: {input_vars}. " + f"Got: {values['input_variables']}" + ) + else: + values["input_variables"] = sorted(input_vars) + values["input_types"] = input_types + return values + + @classmethod + def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate: + """Create a chat prompt template from a template string. + + Creates a chat template consisting of a single message assumed to be from + the human. + + Args: + template: template string + **kwargs: keyword arguments to pass to the constructor. + + Returns: + A new instance of this class. + """ + prompt_template = PromptTemplate.from_template(template, **kwargs) + message = HumanMessagePromptTemplate(prompt=prompt_template) + return cls.from_messages([message]) + + @classmethod + @deprecated("0.0.260", alternative="from_messages classmethod", pending=True) + def from_role_strings( + cls, string_messages: List[Tuple[str, str]] + ) -> ChatPromptTemplate: + """Create a chat prompt template from a list of (role, template) tuples. + + Args: + string_messages: list of (role, template) tuples. 
+ + Returns: + a chat prompt template + """ + return cls( + messages=[ + ChatMessagePromptTemplate.from_template(template, role=role) + for role, template in string_messages + ] + ) + + @classmethod + @deprecated("0.0.260", alternative="from_messages classmethod", pending=True) + def from_strings( + cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]] + ) -> ChatPromptTemplate: + """Create a chat prompt template from a list of (role class, template) tuples. + + Args: + string_messages: list of (role class, template) tuples. + + Returns: + a chat prompt template + """ + return cls.from_messages(string_messages) + + @classmethod + def from_messages( + cls, + messages: Sequence[MessageLikeRepresentation], + ) -> ChatPromptTemplate: + """Create a chat prompt template from a variety of message formats. + + Examples: + + Instantiation from a list of message templates: + + .. code-block:: python + + template = ChatPromptTemplate.from_messages([ + ("human", "Hello, how are you?"), + ("ai", "I'm doing well, thanks!"), + ("human", "That's good to hear."), + ]) + + Instantiation from mixed message formats: + + .. code-block:: python + + template = ChatPromptTemplate.from_messages([ + SystemMessage(content="hello"), + ("human", "Hello, how are you?"), + ]) + + Args: + messages: sequence of message representations. + A message can be represented using the following formats: + (1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of + (message type, template); e.g., ("human", "{user_input}"), + (4) 2-tuple of (message class, template), (4) a string which is + shorthand for ("human", template); e.g., "{user_input}" + + Returns: + a chat prompt template + """ + _messages = [_convert_to_message(message) for message in messages] + + # Automatically infer input variables from messages + input_vars: Set[str] = set() + for _message in _messages: + if isinstance( + _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate) + ): + input_vars.update(_message.input_variables) + + return cls(input_variables=sorted(input_vars), messages=_messages) + + def format(self, **kwargs: Any) -> str: + """Format the chat template into a string. + + Args: + **kwargs: keyword arguments to use for filling in template variables + in all the template messages in this chat template. + + Returns: + formatted string + """ + return self.format_prompt(**kwargs).to_string() + + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format the chat template into a list of finalized messages. + + Args: + **kwargs: keyword arguments to use for filling in template variables + in all the template messages in this chat template. + + Returns: + list of formatted messages + """ + kwargs = self._merge_partial_and_user_variables(**kwargs) + result = [] + for message_template in self.messages: + if isinstance(message_template, BaseMessage): + result.extend([message_template]) + elif isinstance( + message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate) + ): + rel_params = { + k: v + for k, v in kwargs.items() + if k in message_template.input_variables + } + message = message_template.format_messages(**rel_params) + result.extend(message) + else: + raise ValueError(f"Unexpected input: {message_template}") + return result + + def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate: + """Get a new ChatPromptTemplate with some input variables already filled in. + + Args: + **kwargs: keyword arguments to use for filling in template variables. 
Ought + to be a subset of the input variables. + + Returns: + A new ChatPromptTemplate. + + + Example: + + .. code-block:: python + + from langchain_core.prompts import ChatPromptTemplate + + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are an AI assistant named {name}."), + ("human", "Hi I'm {user}"), + ("ai", "Hi there, {user}, I'm {name}."), + ("human", "{input}"), + ] + ) + template2 = template.partial(user="Lucy", name="R2D2") + + template2.format_messages(input="hello") + """ + prompt_dict = self.__dict__.copy() + prompt_dict["input_variables"] = list( + set(self.input_variables).difference(kwargs) + ) + prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} + return type(self)(**prompt_dict) + + def append(self, message: MessageLikeRepresentation) -> None: + """Append message to the end of the chat template. + + Args: + message: representation of a message to append. + """ + self.messages.append(_convert_to_message(message)) + + def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None: + """Extend the chat template with a sequence of messages.""" + self.messages.extend([_convert_to_message(message) for message in messages]) + + @overload + def __getitem__(self, index: int) -> MessageLike: + ... + + @overload + def __getitem__(self, index: slice) -> ChatPromptTemplate: + ... + + def __getitem__( + self, index: Union[int, slice] + ) -> Union[MessageLike, ChatPromptTemplate]: + """Use to index into the chat template.""" + if isinstance(index, slice): + start, stop, step = index.indices(len(self.messages)) + messages = self.messages[start:stop:step] + return ChatPromptTemplate.from_messages(messages) + else: + return self.messages[index] + + def __len__(self) -> int: + """Get the length of the chat template.""" + return len(self.messages) + + @property + def _prompt_type(self) -> str: + """Name of prompt type.""" + return "chat" + + def save(self, file_path: Union[Path, str]) -> None: + """Save prompt to file. + + Args: + file_path: path to file. + """ + raise NotImplementedError() + + +def _create_template_from_message_type( + message_type: str, template: str +) -> BaseMessagePromptTemplate: + """Create a message prompt template from a message type and template string. + + Args: + message_type: str the type of the message template (e.g., "human", "ai", etc.) + template: str the template string. + + Returns: + a message prompt template of the appropriate type. + """ + if message_type in ("human", "user"): + message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template( + template + ) + elif message_type in ("ai", "assistant"): + message = AIMessagePromptTemplate.from_template(template) + elif message_type == "system": + message = SystemMessagePromptTemplate.from_template(template) + else: + raise ValueError( + f"Unexpected message type: {message_type}. Use one of 'human'," + f" 'user', 'ai', 'assistant', or 'system'." + ) + return message + + +def _convert_to_message( + message: MessageLikeRepresentation, +) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]: + """Instantiate a message from a variety of message formats. 
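[Editor's note] The mutation and indexing helpers above (`append`, `extend`, `__getitem__`, `__len__`) make `ChatPromptTemplate` behave like a list of messages. A brief sketch with illustrative content:

.. code-block:: python

    from langchain_core.prompts import ChatPromptTemplate

    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are terse."),
        ("human", "{question}"),
    ])

    prompt.append(("ai", "Understood."))   # any MessageLikeRepresentation is accepted
    len(prompt)                            # 3
    prompt[0]                              # the system message template
    prompt[:2]                             # a new ChatPromptTemplate with the first two messages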
+ + The message format can be one of the following: + + - BaseMessagePromptTemplate + - BaseMessage + - 2-tuple of (role string, template); e.g., ("human", "{user_input}") + - 2-tuple of (message class, template) + - string: shorthand for ("human", template); e.g., "{user_input}" + + Args: + message: a representation of a message in one of the supported formats + + Returns: + an instance of a message or a message template + """ + if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)): + _message: Union[ + BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate + ] = message + elif isinstance(message, BaseMessage): + _message = message + elif isinstance(message, str): + _message = _create_template_from_message_type("human", message) + elif isinstance(message, tuple): + if len(message) != 2: + raise ValueError(f"Expected 2-tuple of (role, template), got {message}") + message_type_str, template = message + if isinstance(message_type_str, str): + _message = _create_template_from_message_type(message_type_str, template) + else: + _message = message_type_str(prompt=PromptTemplate.from_template(template)) + else: + raise NotImplementedError(f"Unsupported message type: {type(message)}") + + return _message diff --git a/libs/core/langchain_core/prompts/example_selector/__init__.py b/libs/core/langchain_core/prompts/example_selector/__init__.py new file mode 100644 index 00000000000..02eeaf00f7c --- /dev/null +++ b/libs/core/langchain_core/prompts/example_selector/__init__.py @@ -0,0 +1,14 @@ +"""Logic for selecting examples to include in prompts.""" +from langchain_core.prompts.example_selector.length_based import ( + LengthBasedExampleSelector, +) +from langchain_core.prompts.example_selector.semantic_similarity import ( + MaxMarginalRelevanceExampleSelector, + SemanticSimilarityExampleSelector, +) + +__all__ = [ + "LengthBasedExampleSelector", + "MaxMarginalRelevanceExampleSelector", + "SemanticSimilarityExampleSelector", +] diff --git a/libs/core/langchain_core/prompts/example_selector/base.py b/libs/core/langchain_core/prompts/example_selector/base.py new file mode 100644 index 00000000000..ff2e099c810 --- /dev/null +++ b/libs/core/langchain_core/prompts/example_selector/base.py @@ -0,0 +1,15 @@ +"""Interface for selecting examples to include in prompts.""" +from abc import ABC, abstractmethod +from typing import Any, Dict, List + + +class BaseExampleSelector(ABC): + """Interface for selecting examples to include in prompts.""" + + @abstractmethod + def add_example(self, example: Dict[str, str]) -> Any: + """Add new example to store for a key.""" + + @abstractmethod + def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + """Select which examples to use based on the inputs.""" diff --git a/libs/core/langchain_core/prompts/example_selector/length_based.py b/libs/core/langchain_core/prompts/example_selector/length_based.py new file mode 100644 index 00000000000..0604461d6e1 --- /dev/null +++ b/libs/core/langchain_core/prompts/example_selector/length_based.py @@ -0,0 +1,63 @@ +"""Select examples based on length.""" +import re +from typing import Callable, Dict, List + +from langchain_core.prompts.example_selector.base import BaseExampleSelector +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import BaseModel, validator + + +def _get_length_based(text: str) -> int: + return len(re.split("\n| ", text)) + + +class LengthBasedExampleSelector(BaseExampleSelector, BaseModel): + """Select examples based on 
length.""" + + examples: List[dict] + """A list of the examples that the prompt template expects.""" + + example_prompt: PromptTemplate + """Prompt template used to format the examples.""" + + get_text_length: Callable[[str], int] = _get_length_based + """Function to measure prompt length. Defaults to word count.""" + + max_length: int = 2048 + """Max length for the prompt, beyond which examples are cut.""" + + example_text_lengths: List[int] = [] #: :meta private: + + def add_example(self, example: Dict[str, str]) -> None: + """Add new example to list.""" + self.examples.append(example) + string_example = self.example_prompt.format(**example) + self.example_text_lengths.append(self.get_text_length(string_example)) + + @validator("example_text_lengths", always=True) + def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]: + """Calculate text lengths if they don't exist.""" + # Check if text lengths were passed in + if v: + return v + # If they were not, calculate them + example_prompt = values["example_prompt"] + get_text_length = values["get_text_length"] + string_examples = [example_prompt.format(**eg) for eg in values["examples"]] + return [get_text_length(eg) for eg in string_examples] + + def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + """Select which examples to use based on the input lengths.""" + inputs = " ".join(input_variables.values()) + remaining_length = self.max_length - self.get_text_length(inputs) + i = 0 + examples = [] + while remaining_length > 0 and i < len(self.examples): + new_length = remaining_length - self.example_text_lengths[i] + if new_length < 0: + break + else: + examples.append(self.examples[i]) + remaining_length = new_length + i += 1 + return examples diff --git a/libs/core/langchain_core/prompts/example_selector/semantic_similarity.py b/libs/core/langchain_core/prompts/example_selector/semantic_similarity.py new file mode 100644 index 00000000000..d0b7435c8aa --- /dev/null +++ b/libs/core/langchain_core/prompts/example_selector/semantic_similarity.py @@ -0,0 +1,165 @@ +"""Example selector that selects examples based on SemanticSimilarity.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Type + +from langchain_core.prompts.example_selector.base import BaseExampleSelector +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + + +def sorted_values(values: Dict[str, str]) -> List[Any]: + """Return a list of values in dict sorted by key.""" + return [values[val] for val in sorted(values)] + + +class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel): + """Example selector that selects examples based on SemanticSimilarity.""" + + vectorstore: VectorStore + """VectorStore than contains information about examples.""" + k: int = 4 + """Number of examples to select.""" + example_keys: Optional[List[str]] = None + """Optional keys to filter examples to.""" + input_keys: Optional[List[str]] = None + """Optional keys to filter input to. 
If provided, the search is based on + the input variables instead of all variables.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + def add_example(self, example: Dict[str, str]) -> str: + """Add new example to vectorstore.""" + if self.input_keys: + string_example = " ".join( + sorted_values({key: example[key] for key in self.input_keys}) + ) + else: + string_example = " ".join(sorted_values(example)) + ids = self.vectorstore.add_texts([string_example], metadatas=[example]) + return ids[0] + + def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + """Select which examples to use based on semantic similarity.""" + # Get the docs with the highest similarity. + if self.input_keys: + input_variables = {key: input_variables[key] for key in self.input_keys} + query = " ".join(sorted_values(input_variables)) + example_docs = self.vectorstore.similarity_search(query, k=self.k) + # Get the examples from the metadata. + # This assumes that examples are stored in metadata. + examples = [dict(e.metadata) for e in example_docs] + # If example keys are provided, filter examples to those keys. + if self.example_keys: + examples = [{k: eg[k] for k in self.example_keys} for eg in examples] + return examples + + @classmethod + def from_examples( + cls, + examples: List[dict], + embeddings: Embeddings, + vectorstore_cls: Type[VectorStore], + k: int = 4, + input_keys: Optional[List[str]] = None, + **vectorstore_cls_kwargs: Any, + ) -> SemanticSimilarityExampleSelector: + """Create k-shot example selector using example list and embeddings. + + Reshuffles examples dynamically based on query similarity. + + Args: + examples: List of examples to use in the prompt. + embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings(). + vectorstore_cls: A vector store DB interface class, e.g. FAISS. + k: Number of examples to select + input_keys: If provided, the search is based on the input variables + instead of all variables. + vectorstore_cls_kwargs: optional kwargs containing url for vector store + + Returns: + The ExampleSelector instantiated, backed by a vector store. + """ + if input_keys: + string_examples = [ + " ".join(sorted_values({k: eg[k] for k in input_keys})) + for eg in examples + ] + else: + string_examples = [" ".join(sorted_values(eg)) for eg in examples] + vectorstore = vectorstore_cls.from_texts( + string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs + ) + return cls(vectorstore=vectorstore, k=k, input_keys=input_keys) + + +class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector): + """ExampleSelector that selects examples based on Max Marginal Relevance. + + This was shown to improve performance in this paper: + https://arxiv.org/pdf/2211.13892.pdf + """ + + fetch_k: int = 20 + """Number of examples to fetch to rerank.""" + + def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + """Select which examples to use based on semantic similarity.""" + # Get the docs with the highest similarity. + if self.input_keys: + input_variables = {key: input_variables[key] for key in self.input_keys} + query = " ".join(sorted_values(input_variables)) + example_docs = self.vectorstore.max_marginal_relevance_search( + query, k=self.k, fetch_k=self.fetch_k + ) + # Get the examples from the metadata. + # This assumes that examples are stored in metadata. 
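[Editor's note] Exercising the semantic selector requires an `Embeddings` implementation and a `VectorStore`, neither of which ships in `langchain_core`. The sketch below assumes `OpenAIEmbeddings` and `FAISS` from the companion `langchain` package (and a valid API key plus the `faiss` dependency) and is purely illustrative:

.. code-block:: python

    from langchain.embeddings import OpenAIEmbeddings   # assumed external dependency
    from langchain.vectorstores import FAISS             # assumed external dependency

    from langchain_core.prompts.example_selector import SemanticSimilarityExampleSelector

    selector = SemanticSimilarityExampleSelector.from_examples(
        examples=[
            {"input": "happy", "output": "sad"},
            {"input": "windy", "output": "calm"},
        ],
        embeddings=OpenAIEmbeddings(),
        vectorstore_cls=FAISS,
        k=1,
    )

    selector.select_examples({"input": "joyful"})
    # Expected to return the example closest by embedding similarity,
    # e.g. [{'input': 'happy', 'output': 'sad'}]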
+ examples = [dict(e.metadata) for e in example_docs] + # If example keys are provided, filter examples to those keys. + if self.example_keys: + examples = [{k: eg[k] for k in self.example_keys} for eg in examples] + return examples + + @classmethod + def from_examples( + cls, + examples: List[dict], + embeddings: Embeddings, + vectorstore_cls: Type[VectorStore], + k: int = 4, + input_keys: Optional[List[str]] = None, + fetch_k: int = 20, + **vectorstore_cls_kwargs: Any, + ) -> MaxMarginalRelevanceExampleSelector: + """Create k-shot example selector using example list and embeddings. + + Reshuffles examples dynamically based on query similarity. + + Args: + examples: List of examples to use in the prompt. + embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings(). + vectorstore_cls: A vector store DB interface class, e.g. FAISS. + k: Number of examples to select + input_keys: If provided, the search is based on the input variables + instead of all variables. + vectorstore_cls_kwargs: optional kwargs containing url for vector store + + Returns: + The ExampleSelector instantiated, backed by a vector store. + """ + if input_keys: + string_examples = [ + " ".join(sorted_values({k: eg[k] for k in input_keys})) + for eg in examples + ] + else: + string_examples = [" ".join(sorted_values(eg)) for eg in examples] + vectorstore = vectorstore_cls.from_texts( + string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs + ) + return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys) diff --git a/libs/core/langchain_core/prompts/few_shot.py b/libs/core/langchain_core/prompts/few_shot.py new file mode 100644 index 00000000000..b53c0a7ec56 --- /dev/null +++ b/libs/core/langchain_core/prompts/few_shot.py @@ -0,0 +1,343 @@ +"""Prompt template that contains few shot examples.""" +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union + +from langchain_core.prompts.base import ( + DEFAULT_FORMATTER_MAPPING, + StringPromptTemplate, + check_valid_template, + get_template_variables, +) +from langchain_core.prompts.chat import ( + BaseChatPromptTemplate, + BaseMessagePromptTemplate, +) +from langchain_core.prompts.example_selector.base import BaseExampleSelector +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.schema.messages import BaseMessage, get_buffer_string + + +class _FewShotPromptTemplateMixin(BaseModel): + """Prompt template that contains few shot examples.""" + + examples: Optional[List[dict]] = None + """Examples to format into the prompt. + Either this or example_selector should be provided.""" + + example_selector: Optional[BaseExampleSelector] = None + """ExampleSelector to choose the examples to format into the prompt. 
+ Either this or examples should be provided.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @root_validator(pre=True) + def check_examples_and_selector(cls, values: Dict) -> Dict: + """Check that one and only one of examples/example_selector are provided.""" + examples = values.get("examples", None) + example_selector = values.get("example_selector", None) + if examples and example_selector: + raise ValueError( + "Only one of 'examples' and 'example_selector' should be provided" + ) + + if examples is None and example_selector is None: + raise ValueError( + "One of 'examples' and 'example_selector' should be provided" + ) + + return values + + def _get_examples(self, **kwargs: Any) -> List[dict]: + """Get the examples to use for formatting the prompt. + + Args: + **kwargs: Keyword arguments to be passed to the example selector. + + Returns: + List of examples. + """ + if self.examples is not None: + return self.examples + elif self.example_selector is not None: + return self.example_selector.select_examples(kwargs) + else: + raise ValueError( + "One of 'examples' and 'example_selector' should be provided" + ) + + +class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): + """Prompt template that contains few shot examples.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether or not the class is serializable.""" + return False + + validate_template: bool = False + """Whether or not to try validating the template.""" + + input_variables: List[str] + """A list of the names of the variables the prompt template expects.""" + + example_prompt: PromptTemplate + """PromptTemplate used to format an individual example.""" + + suffix: str + """A prompt template string to put after the examples.""" + + example_separator: str = "\n\n" + """String separator used to join the prefix, the examples, and suffix.""" + + prefix: str = "" + """A prompt template string to put before the examples.""" + + template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string" + """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" + + @root_validator() + def template_is_valid(cls, values: Dict) -> Dict: + """Check that prefix, suffix, and input variables are consistent.""" + if values["validate_template"]: + check_valid_template( + values["prefix"] + values["suffix"], + values["template_format"], + values["input_variables"] + list(values["partial_variables"]), + ) + elif values.get("template_format"): + values["input_variables"] = [ + var + for var in get_template_variables( + values["prefix"] + values["suffix"], values["template_format"] + ) + if var not in values["partial_variables"] + ] + return values + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs. + + Args: + **kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. code-block:: python + + prompt.format(variable1="foo") + """ + kwargs = self._merge_partial_and_user_variables(**kwargs) + # Get the examples to use. + examples = self._get_examples(**kwargs) + examples = [ + {k: e[k] for k in self.example_prompt.input_variables} for e in examples + ] + # Format the examples. 
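[Editor's note] A minimal, self-contained sketch of `FewShotPromptTemplate` with a fixed example list; the questions and answers are illustrative:

.. code-block:: python

    from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

    example_prompt = PromptTemplate.from_template("Q: {question}\nA: {answer}")

    prompt = FewShotPromptTemplate(
        examples=[
            {"question": "2+2", "answer": "4"},
            {"question": "2+3", "answer": "5"},
        ],
        example_prompt=example_prompt,
        prefix="Answer the question like the examples below.",
        suffix="Q: {question}\nA:",
        input_variables=["question"],
    )

    print(prompt.format(question="4+4"))
    # Answer the question like the examples below.
    #
    # Q: 2+2
    # A: 4
    #
    # Q: 2+3
    # A: 5
    #
    # Q: 4+4
    # A:

The prefix, the formatted examples, and the suffix are joined with `example_separator` and then formatted once more with the remaining input variables.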
+ example_strings = [ + self.example_prompt.format(**example) for example in examples + ] + # Create the overall template. + pieces = [self.prefix, *example_strings, self.suffix] + template = self.example_separator.join([piece for piece in pieces if piece]) + + # Format the template with the input variables. + return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs) + + @property + def _prompt_type(self) -> str: + """Return the prompt type key.""" + return "few_shot" + + def save(self, file_path: Union[Path, str]) -> None: + if self.example_selector: + raise ValueError("Saving an example selector is not currently supported") + return super().save(file_path) + + +class FewShotChatMessagePromptTemplate( + BaseChatPromptTemplate, _FewShotPromptTemplateMixin +): + """Chat prompt template that supports few-shot examples. + + The high level structure of produced by this prompt template is a list of messages + consisting of prefix message(s), example message(s), and suffix message(s). + + This structure enables creating a conversation with intermediate examples like: + + System: You are a helpful AI Assistant + Human: What is 2+2? + AI: 4 + Human: What is 2+3? + AI: 5 + Human: What is 4+4? + + This prompt template can be used to generate a fixed list of examples or else + to dynamically select examples based on the input. + + Examples: + + Prompt template with a fixed list of examples (matching the sample + conversation above): + + .. code-block:: python + + from langchain_core.prompts import ( + FewShotChatMessagePromptTemplate, + ChatPromptTemplate + ) + + examples = [ + {"input": "2+2", "output": "4"}, + {"input": "2+3", "output": "5"}, + ] + + example_prompt = ChatPromptTemplate.from_messages( + [('human', '{input}'), ('ai', '{output}')] + ) + + few_shot_prompt = FewShotChatMessagePromptTemplate( + examples=examples, + # This is a prompt template used to format each individual example. + example_prompt=example_prompt, + ) + + final_prompt = ChatPromptTemplate.from_messages( + [ + ('system', 'You are a helpful AI Assistant'), + few_shot_prompt, + ('human', '{input}'), + ] + ) + final_prompt.format(input="What is 4+4?") + + Prompt template with dynamically selected examples: + + .. code-block:: python + + from langchain_core.prompts import SemanticSimilarityExampleSelector + from langchain_core.embeddings import OpenAIEmbeddings + from langchain_core.vectorstores import Chroma + + examples = [ + {"input": "2+2", "output": "4"}, + {"input": "2+3", "output": "5"}, + {"input": "2+4", "output": "6"}, + # ... + ] + + to_vectorize = [ + " ".join(example.values()) + for example in examples + ] + embeddings = OpenAIEmbeddings() + vectorstore = Chroma.from_texts( + to_vectorize, embeddings, metadatas=examples + ) + example_selector = SemanticSimilarityExampleSelector( + vectorstore=vectorstore + ) + + from langchain_core.schema import SystemMessage + from langchain_core.prompts import HumanMessagePromptTemplate + from langchain_core.prompts.few_shot import FewShotChatMessagePromptTemplate + + few_shot_prompt = FewShotChatMessagePromptTemplate( + # Which variable(s) will be passed to the example selector. + input_variables=["input"], + example_selector=example_selector, + # Define how each example will be formatted. + # In this case, each example will become 2 messages: + # 1 human, and 1 AI + example_prompt=( + HumanMessagePromptTemplate.from_template("{input}") + + AIMessagePromptTemplate.from_template("{output}") + ), + ) + # Define the overall prompt. 
+ final_prompt = ( + SystemMessagePromptTemplate.from_template( + "You are a helpful AI Assistant" + ) + + few_shot_prompt + + HumanMessagePromptTemplate.from_template("{input}") + ) + # Show the prompt + print(final_prompt.format_messages(input="What's 3+3?")) + + # Use within an LLM + from langchain_core.chat_models import ChatAnthropic + chain = final_prompt | ChatAnthropic() + chain.invoke({"input": "What's 3+3?"}) + """ + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether or not the class is serializable.""" + return False + + input_variables: List[str] = Field(default_factory=list) + """A list of the names of the variables the prompt template will use + to pass to the example_selector, if provided.""" + example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate] + """The class to format each example.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format kwargs into a list of messages. + + Args: + **kwargs: keyword arguments to use for filling in templates in messages. + + Returns: + A list of formatted messages with all template variables filled in. + """ + # Get the examples to use. + examples = self._get_examples(**kwargs) + examples = [ + {k: e[k] for k in self.example_prompt.input_variables} for e in examples + ] + # Format the examples. + messages = [ + message + for example in examples + for message in self.example_prompt.format_messages(**example) + ] + return messages + + def format(self, **kwargs: Any) -> str: + """Format the prompt with inputs generating a string. + + Use this method to generate a string representation of a prompt consisting + of chat messages. + + Useful for feeding into a string based completion language model or debugging. + + Args: + **kwargs: keyword arguments to use for formatting. + + Returns: + A string representation of the prompt + """ + messages = self.format_messages(**kwargs) + return get_buffer_string(messages) diff --git a/libs/core/langchain_core/prompts/few_shot_with_templates.py b/libs/core/langchain_core/prompts/few_shot_with_templates.py new file mode 100644 index 00000000000..682a392bb08 --- /dev/null +++ b/libs/core/langchain_core/prompts/few_shot_with_templates.py @@ -0,0 +1,153 @@ +"""Prompt template that contains few shot examples.""" +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from langchain_core.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate +from langchain_core.prompts.example_selector.base import BaseExampleSelector +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Extra, root_validator + + +class FewShotPromptWithTemplates(StringPromptTemplate): + """Prompt template that contains few shot examples.""" + + examples: Optional[List[dict]] = None + """Examples to format into the prompt. + Either this or example_selector should be provided.""" + + example_selector: Optional[BaseExampleSelector] = None + """ExampleSelector to choose the examples to format into the prompt. 
+ Either this or examples should be provided.""" + + example_prompt: PromptTemplate + """PromptTemplate used to format an individual example.""" + + suffix: StringPromptTemplate + """A PromptTemplate to put after the examples.""" + + input_variables: List[str] + """A list of the names of the variables the prompt template expects.""" + + example_separator: str = "\n\n" + """String separator used to join the prefix, the examples, and suffix.""" + + prefix: Optional[StringPromptTemplate] = None + """A PromptTemplate to put before the examples.""" + + template_format: str = "f-string" + """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" + + validate_template: bool = False + """Whether or not to try validating the template.""" + + @root_validator(pre=True) + def check_examples_and_selector(cls, values: Dict) -> Dict: + """Check that one and only one of examples/example_selector are provided.""" + examples = values.get("examples", None) + example_selector = values.get("example_selector", None) + if examples and example_selector: + raise ValueError( + "Only one of 'examples' and 'example_selector' should be provided" + ) + + if examples is None and example_selector is None: + raise ValueError( + "One of 'examples' and 'example_selector' should be provided" + ) + + return values + + @root_validator() + def template_is_valid(cls, values: Dict) -> Dict: + """Check that prefix, suffix, and input variables are consistent.""" + if values["validate_template"]: + input_variables = values["input_variables"] + expected_input_variables = set(values["suffix"].input_variables) + expected_input_variables |= set(values["partial_variables"]) + if values["prefix"] is not None: + expected_input_variables |= set(values["prefix"].input_variables) + missing_vars = expected_input_variables.difference(input_variables) + if missing_vars: + raise ValueError( + f"Got input_variables={input_variables}, but based on " + f"prefix/suffix expected {expected_input_variables}" + ) + else: + values["input_variables"] = sorted( + set(values["suffix"].input_variables) + | set(values["prefix"].input_variables if values["prefix"] else []) + - set(values["partial_variables"]) + ) + return values + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + def _get_examples(self, **kwargs: Any) -> List[dict]: + if self.examples is not None: + return self.examples + elif self.example_selector is not None: + return self.example_selector.select_examples(kwargs) + else: + raise ValueError + + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. code-block:: python + + prompt.format(variable1="foo") + """ + kwargs = self._merge_partial_and_user_variables(**kwargs) + # Get the examples to use. + examples = self._get_examples(**kwargs) + # Format the examples. + example_strings = [ + self.example_prompt.format(**example) for example in examples + ] + # Create the overall prefix. 
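[Editor's note] `FewShotPromptWithTemplates` differs from `FewShotPromptTemplate` above in that the prefix and suffix are themselves prompt templates rather than plain strings. A brief sketch with illustrative content:

.. code-block:: python

    from langchain_core.prompts import FewShotPromptWithTemplates, PromptTemplate

    example_prompt = PromptTemplate.from_template("Q: {question}\nA: {answer}")

    prompt = FewShotPromptWithTemplates(
        examples=[{"question": "2+2", "answer": "4"}],
        example_prompt=example_prompt,
        prefix=PromptTemplate.from_template("You answer {subject} questions."),
        suffix=PromptTemplate.from_template("Q: {question}\nA:"),
        input_variables=["question", "subject"],
    )

    print(prompt.format(subject="arithmetic", question="4+4"))
    # You answer arithmetic questions.
    #
    # Q: 2+2
    # A: 4
    #
    # Q: 4+4
    # A: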
+ if self.prefix is None: + prefix = "" + else: + prefix_kwargs = { + k: v for k, v in kwargs.items() if k in self.prefix.input_variables + } + for k in prefix_kwargs.keys(): + kwargs.pop(k) + prefix = self.prefix.format(**prefix_kwargs) + + # Create the overall suffix + suffix_kwargs = { + k: v for k, v in kwargs.items() if k in self.suffix.input_variables + } + for k in suffix_kwargs.keys(): + kwargs.pop(k) + suffix = self.suffix.format( + **suffix_kwargs, + ) + + pieces = [prefix, *example_strings, suffix] + template = self.example_separator.join([piece for piece in pieces if piece]) + # Format the template with the input variables. + return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs) + + @property + def _prompt_type(self) -> str: + """Return the prompt type key.""" + return "few_shot_with_templates" + + def save(self, file_path: Union[Path, str]) -> None: + if self.example_selector: + raise ValueError("Saving an example selector is not currently supported") + return super().save(file_path) diff --git a/libs/core/langchain_core/prompts/loading.py b/libs/core/langchain_core/prompts/loading.py new file mode 100644 index 00000000000..69238db0fea --- /dev/null +++ b/libs/core/langchain_core/prompts/loading.py @@ -0,0 +1,162 @@ +"""Load prompts.""" +import json +import logging +from pathlib import Path +from typing import Callable, Dict, Union + +import yaml + +from langchain_core.prompts.few_shot import FewShotPromptTemplate +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.schema import ( + BasePromptTemplate, + StrOutputParser, +) +from langchain_core.utils.loading import try_load_from_hub + +URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" +logger = logging.getLogger(__name__) + + +def load_prompt_from_config(config: dict) -> BasePromptTemplate: + """Load prompt from Config Dict.""" + if "_type" not in config: + logger.warning("No `_type` key found, defaulting to `prompt`.") + config_type = config.pop("_type", "prompt") + + if config_type not in type_to_loader_dict: + raise ValueError(f"Loading {config_type} prompt not supported") + + prompt_loader = type_to_loader_dict[config_type] + return prompt_loader(config) + + +def _load_template(var_name: str, config: dict) -> dict: + """Load template from the path if applicable.""" + # Check if template_path exists in config. + if f"{var_name}_path" in config: + # If it does, make sure template variable doesn't also exist. + if var_name in config: + raise ValueError( + f"Both `{var_name}_path` and `{var_name}` cannot be provided." + ) + # Pop the template path from the config. + template_path = Path(config.pop(f"{var_name}_path")) + # Load the template. + if template_path.suffix == ".txt": + with open(template_path) as f: + template = f.read() + else: + raise ValueError + # Set the template variable to the extracted variable. + config[var_name] = template + return config + + +def _load_examples(config: dict) -> dict: + """Load examples if necessary.""" + if isinstance(config["examples"], list): + pass + elif isinstance(config["examples"], str): + with open(config["examples"]) as f: + if config["examples"].endswith(".json"): + examples = json.load(f) + elif config["examples"].endswith((".yaml", ".yml")): + examples = yaml.safe_load(f) + else: + raise ValueError( + "Invalid file format. Only json or yaml formats are supported." + ) + config["examples"] = examples + else: + raise ValueError("Invalid examples format. 
Only list or string are supported.") + return config + + +def _load_output_parser(config: dict) -> dict: + """Load output parser.""" + if "output_parser" in config and config["output_parser"]: + _config = config.pop("output_parser") + output_parser_type = _config.pop("_type") + if output_parser_type == "default": + output_parser = StrOutputParser(**_config) + else: + raise ValueError(f"Unsupported output parser {output_parser_type}") + config["output_parser"] = output_parser + return config + + +def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate: + """Load the "few shot" prompt from the config.""" + # Load the suffix and prefix templates. + config = _load_template("suffix", config) + config = _load_template("prefix", config) + # Load the example prompt. + if "example_prompt_path" in config: + if "example_prompt" in config: + raise ValueError( + "Only one of example_prompt and example_prompt_path should " + "be specified." + ) + config["example_prompt"] = load_prompt(config.pop("example_prompt_path")) + else: + config["example_prompt"] = load_prompt_from_config(config["example_prompt"]) + # Load the examples. + config = _load_examples(config) + config = _load_output_parser(config) + return FewShotPromptTemplate(**config) + + +def _load_prompt(config: dict) -> PromptTemplate: + """Load the prompt template from config.""" + # Load the template from disk if necessary. + config = _load_template("template", config) + config = _load_output_parser(config) + + template_format = config.get("template_format", "f-string") + if template_format == "jinja2": + # Disabled due to: + # https://github.com/langchain-ai/langchain/issues/4394 + raise ValueError( + f"Loading templates with '{template_format}' format is no longer supported " + f"since it can lead to arbitrary code execution. Please migrate to using " + f"the 'f-string' template format, which does not suffer from this issue." + ) + + return PromptTemplate(**config) + + +def load_prompt(path: Union[str, Path]) -> BasePromptTemplate: + """Unified method for loading a prompt from LangChainHub or local fs.""" + if hub_result := try_load_from_hub( + path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"} + ): + return hub_result + else: + return _load_prompt_from_file(path) + + +def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: + """Load prompt from file.""" + # Convert file to a Path object. + if isinstance(file, str): + file_path = Path(file) + else: + file_path = file + # Load from either json or yaml. + if file_path.suffix == ".json": + with open(file_path) as f: + config = json.load(f) + elif file_path.suffix == ".yaml": + with open(file_path, "r") as f: + config = yaml.safe_load(f) + else: + raise ValueError(f"Got unsupported file type {file_path.suffix}") + # Load the prompt from the config now. 
+ return load_prompt_from_config(config) + + +type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = { + "prompt": _load_prompt, + "few_shot": _load_few_shot_prompt, +} diff --git a/libs/core/langchain_core/prompts/pipeline.py b/libs/core/langchain_core/prompts/pipeline.py new file mode 100644 index 00000000000..dc39c592186 --- /dev/null +++ b/libs/core/langchain_core/prompts/pipeline.py @@ -0,0 +1,56 @@ +from typing import Any, Dict, List, Tuple + +from langchain_core.prompts.chat import BaseChatPromptTemplate +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import BasePromptTemplate, PromptValue + + +def _get_inputs(inputs: dict, input_variables: List[str]) -> dict: + return {k: inputs[k] for k in input_variables} + + +class PipelinePromptTemplate(BasePromptTemplate): + """A prompt template for composing multiple prompt templates together. + + This can be useful when you want to reuse parts of prompts. + A PipelinePrompt consists of two main parts: + - final_prompt: This is the final prompt that is returned + - pipeline_prompts: This is a list of tuples, consisting + of a string (`name`) and a Prompt Template. + Each PromptTemplate will be formatted and then passed + to future prompt templates as a variable with + the same name as `name` + """ + + final_prompt: BasePromptTemplate + """The final prompt that is returned.""" + pipeline_prompts: List[Tuple[str, BasePromptTemplate]] + """A list of tuples, consisting of a string (`name`) and a Prompt Template.""" + + @root_validator(pre=True) + def get_input_variables(cls, values: Dict) -> Dict: + """Get input variables.""" + created_variables = set() + all_variables = set() + for k, prompt in values["pipeline_prompts"]: + created_variables.add(k) + all_variables.update(prompt.input_variables) + values["input_variables"] = list(all_variables.difference(created_variables)) + return values + + def format_prompt(self, **kwargs: Any) -> PromptValue: + for k, prompt in self.pipeline_prompts: + _inputs = _get_inputs(kwargs, prompt.input_variables) + if isinstance(prompt, BaseChatPromptTemplate): + kwargs[k] = prompt.format_messages(**_inputs) + else: + kwargs[k] = prompt.format(**_inputs) + _inputs = _get_inputs(kwargs, self.final_prompt.input_variables) + return self.final_prompt.format_prompt(**_inputs) + + def format(self, **kwargs: Any) -> str: + return self.format_prompt(**kwargs).to_string() + + @property + def _prompt_type(self) -> str: + raise ValueError diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py new file mode 100644 index 00000000000..349bc2f33b2 --- /dev/null +++ b/libs/core/langchain_core/prompts/prompt.py @@ -0,0 +1,250 @@ +"""Prompt schema definition.""" +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union + +from langchain_core.prompts.base import ( + DEFAULT_FORMATTER_MAPPING, + StringPromptTemplate, + check_valid_template, + get_template_variables, +) +from langchain_core.pydantic_v1 import root_validator + + +class PromptTemplate(StringPromptTemplate): + """A prompt template for a language model. + + A prompt template consists of a string template. It accepts a set of parameters + from the user that can be used to generate a prompt for a language model. + + The template can be formatted using either f-strings (default) or jinja2 syntax. 
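Stepping back to the PipelinePromptTemplate added in pipeline.py above, here is a brief, hedged usage sketch; the template strings and variable names are illustrative assumptions only:

.. code-block:: python

    from langchain_core.prompts.pipeline import PipelinePromptTemplate
    from langchain_core.prompts.prompt import PromptTemplate

    full = PromptTemplate.from_template("{introduction}\n\n{question}")
    introduction = PromptTemplate.from_template("You are impersonating {person}.")
    question = PromptTemplate.from_template("Q: {user_question}\nA:")

    pipeline = PipelinePromptTemplate(
        final_prompt=full,
        pipeline_prompts=[("introduction", introduction), ("question", question)],
    )
    # Only variables not produced by earlier pipeline prompts remain as inputs.
    print(pipeline.input_variables)  # e.g. ['person', 'user_question']
    print(pipeline.format(person="a pirate", user_question="How are you?"))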
+ + *Security warning*: Prefer using `template_format="f-string"` instead of + `template_format="jinja2"`, or make sure to NEVER accept jinja2 templates + from untrusted sources as they may lead to arbitrary Python code execution. + + As of LangChain 0.0.329, Jinja2 templates will be rendered using + Jinja2's SandboxedEnvironment by default. This sand-boxing should + be treated as a best-effort approach rather than a guarantee of security, + as it is an opt-out rather than opt-in approach. + + Despite the sand-boxing, we recommend to never use jinja2 templates + from untrusted sources. + + Example: + + .. code-block:: python + + from langchain_core.prompts import PromptTemplate + + # Instantiation using from_template (recommended) + prompt = PromptTemplate.from_template("Say {foo}") + prompt.format(foo="bar") + + # Instantiation using initializer + prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}") + """ + + @property + def lc_attributes(self) -> Dict[str, Any]: + return { + "template_format": self.template_format, + } + + input_variables: List[str] + """A list of the names of the variables the prompt template expects.""" + + template: str + """The prompt template.""" + + template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string" + """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" + + validate_template: bool = False + """Whether or not to try validating the template.""" + + def __add__(self, other: Any) -> PromptTemplate: + """Override the + operator to allow for combining prompt templates.""" + # Allow for easy combining + if isinstance(other, PromptTemplate): + if self.template_format != "f-string": + raise ValueError( + "Adding prompt templates only supported for f-strings." + ) + if other.template_format != "f-string": + raise ValueError( + "Adding prompt templates only supported for f-strings." + ) + input_variables = list( + set(self.input_variables) | set(other.input_variables) + ) + template = self.template + other.template + # If any do not want to validate, then don't + validate_template = self.validate_template and other.validate_template + partial_variables = {k: v for k, v in self.partial_variables.items()} + for k, v in other.partial_variables.items(): + if k in partial_variables: + raise ValueError("Cannot have same variable partialed twice.") + else: + partial_variables[k] = v + return PromptTemplate( + template=template, + input_variables=input_variables, + partial_variables=partial_variables, + template_format="f-string", + validate_template=validate_template, + ) + elif isinstance(other, str): + prompt = PromptTemplate.from_template(other) + return self + prompt + else: + raise NotImplementedError(f"Unsupported operand type for +: {type(other)}") + + @property + def _prompt_type(self) -> str: + """Return the prompt type key.""" + return "prompt" + + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. 
code-block:: python + + prompt.format(variable1="foo") + """ + kwargs = self._merge_partial_and_user_variables(**kwargs) + return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) + + @root_validator() + def template_is_valid(cls, values: Dict) -> Dict: + """Check that template and input variables are consistent.""" + if values["validate_template"]: + all_inputs = values["input_variables"] + list(values["partial_variables"]) + check_valid_template( + values["template"], values["template_format"], all_inputs + ) + elif values.get("template_format"): + values["input_variables"] = [ + var + for var in get_template_variables( + values["template"], values["template_format"] + ) + if var not in values["partial_variables"] + ] + return values + + @classmethod + def from_examples( + cls, + examples: List[str], + suffix: str, + input_variables: List[str], + example_separator: str = "\n\n", + prefix: str = "", + **kwargs: Any, + ) -> PromptTemplate: + """Take examples in list format with prefix and suffix to create a prompt. + + Intended to be used as a way to dynamically create a prompt from examples. + + Args: + examples: List of examples to use in the prompt. + suffix: String to go after the list of examples. Should generally + set up the user's input. + input_variables: A list of variable names the final prompt template + will expect. + example_separator: The separator to use in between examples. Defaults + to two new line characters. + prefix: String that should go before any examples. Generally includes + examples. Default to an empty string. + + Returns: + The final prompt generated. + """ + template = example_separator.join([prefix, *examples, suffix]) + return cls(input_variables=input_variables, template=template, **kwargs) + + @classmethod + def from_file( + cls, template_file: Union[str, Path], input_variables: List[str], **kwargs: Any + ) -> PromptTemplate: + """Load a prompt from a file. + + Args: + template_file: The path to the file containing the prompt template. + input_variables: A list of variable names the final prompt template + will expect. + + Returns: + The prompt loaded from the file. + """ + with open(str(template_file), "r") as f: + template = f.read() + return cls(input_variables=input_variables, template=template, **kwargs) + + @classmethod + def from_template( + cls, + template: str, + *, + template_format: str = "f-string", + partial_variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> PromptTemplate: + """Load a prompt template from a template. + + *Security warning*: Prefer using `template_format="f-string"` instead of + `template_format="jinja2"`, or make sure to NEVER accept jinja2 templates + from untrusted sources as they may lead to arbitrary Python code execution. + + As of LangChain 0.0.329, Jinja2 templates will be rendered using + Jinja2's SandboxedEnvironment by default. This sand-boxing should + be treated as a best-effort approach rather than a guarantee of security, + as it is an opt-out rather than opt-in approach. + + Despite the sand-boxing, we recommend to never use jinja2 templates + from untrusted sources. + + Args: + template: The template to load. + template_format: The format of the template. Use `jinja2` for jinja2, + and `f-string` or None for f-strings. + partial_variables: A dictionary of variables that can be used to partially + fill in the template. 
For example, if the template is + `"{variable1} {variable2}"`, and `partial_variables` is + `{"variable1": "foo"}`, then the final prompt will be + `"foo {variable2}"`. + + Returns: + The prompt template loaded from the template. + """ + + input_variables = get_template_variables(template, template_format) + _partial_variables = partial_variables or {} + + if _partial_variables: + input_variables = [ + var for var in input_variables if var not in _partial_variables + ] + + return cls( + input_variables=input_variables, + template=template, + template_format=template_format, + partial_variables=_partial_variables, + **kwargs, + ) + + +# For backwards compatibility. +Prompt = PromptTemplate diff --git a/libs/core/langchain_core/pydantic_v1/__init__.py b/libs/core/langchain_core/pydantic_v1/__init__.py new file mode 100644 index 00000000000..6b52dfcabeb --- /dev/null +++ b/libs/core/langchain_core/pydantic_v1/__init__.py @@ -0,0 +1,23 @@ +from importlib import metadata + +## Create namespaces for pydantic v1 and v2. +# This code must stay at the top of the file before other modules may +# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules. +# +# This hack is done for the following reasons: +# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since +# both dependencies and dependents may be stuck on either version of v1 or v2. +# * Creating namespaces for pydantic v1 and v2 should allow us to write code that +# unambiguously uses either v1 or v2 API. +# * This change is easier to roll out and roll back. + +try: + from pydantic.v1 import * # noqa: F403 # type: ignore +except ImportError: + from pydantic import * # noqa: F403 # type: ignore + + +try: + _PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0]) +except metadata.PackageNotFoundError: + _PYDANTIC_MAJOR_VERSION = 0 diff --git a/libs/core/langchain_core/pydantic_v1/dataclasses.py b/libs/core/langchain_core/pydantic_v1/dataclasses.py new file mode 100644 index 00000000000..bb7253c29d7 --- /dev/null +++ b/libs/core/langchain_core/pydantic_v1/dataclasses.py @@ -0,0 +1,4 @@ +try: + from pydantic.v1.dataclasses import * # noqa: F403 +except ImportError: + from pydantic.dataclasses import * # noqa: F403 diff --git a/libs/core/langchain_core/pydantic_v1/main.py b/libs/core/langchain_core/pydantic_v1/main.py new file mode 100644 index 00000000000..4b8f1670e13 --- /dev/null +++ b/libs/core/langchain_core/pydantic_v1/main.py @@ -0,0 +1,4 @@ +try: + from pydantic.v1.main import * # noqa: F403 +except ImportError: + from pydantic.main import * # noqa: F403 diff --git a/libs/core/langchain_core/runnables/__init__.py b/libs/core/langchain_core/runnables/__init__.py new file mode 100644 index 00000000000..3a1f555d1c8 --- /dev/null +++ b/libs/core/langchain_core/runnables/__init__.py @@ -0,0 +1,57 @@ +"""LangChain **Runnable** and the **LangChain Expression Language (LCEL)**. + +The LangChain Expression Language (LCEL) offers a declarative method to build +production-grade programs that harness the power of LLMs. + +Programs created using LCEL and LangChain Runnables inherently support +synchronous, asynchronous, batch, and streaming operations. + +Support for **async** allows servers hosting LCEL based programs to scale better +for higher concurrent loads. + +**Streaming** of intermediate outputs as they're being generated allows for +creating more responsive UX. + +This module contains schema and implementation of LangChain Runnables primitives. 
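Before moving on to the runnables package, a small sketch of the partial_variables handling in PromptTemplate.from_template shown above (the template text here is an invented example):

.. code-block:: python

    from langchain_core.prompts.prompt import PromptTemplate

    prompt = PromptTemplate.from_template(
        "{greeting}, {name}!",
        partial_variables={"greeting": "Hello"},
    )
    # 'greeting' is already filled in, so only 'name' remains an input variable.
    print(prompt.input_variables)       # ['name']
    print(prompt.format(name="world"))  # 'Hello, world!'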
+""" +from langchain_core.runnables.base import ( + Runnable, + RunnableBinding, + RunnableGenerator, + RunnableLambda, + RunnableMap, + RunnableParallel, + RunnableSequence, + RunnableSerializable, +) +from langchain_core.runnables.branch import RunnableBranch +from langchain_core.runnables.config import RunnableConfig, patch_config +from langchain_core.runnables.fallbacks import RunnableWithFallbacks +from langchain_core.runnables.passthrough import RunnablePassthrough +from langchain_core.runnables.router import RouterInput, RouterRunnable +from langchain_core.runnables.utils import ( + ConfigurableField, + ConfigurableFieldMultiOption, + ConfigurableFieldSingleOption, +) + +__all__ = [ + "ConfigurableField", + "ConfigurableFieldSingleOption", + "ConfigurableFieldMultiOption", + "patch_config", + "RouterInput", + "RouterRunnable", + "Runnable", + "RunnableSerializable", + "RunnableBinding", + "RunnableBranch", + "RunnableConfig", + "RunnableGenerator", + "RunnableLambda", + "RunnableMap", + "RunnableParallel", + "RunnablePassthrough", + "RunnableSequence", + "RunnableWithFallbacks", +] diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py new file mode 100644 index 00000000000..384370f9dab --- /dev/null +++ b/libs/core/langchain_core/runnables/base.py @@ -0,0 +1,3026 @@ +from __future__ import annotations + +import asyncio +import inspect +import threading +from abc import ABC, abstractmethod +from concurrent.futures import FIRST_COMPLETED, wait +from functools import partial +from itertools import tee +from operator import itemgetter +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Generic, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) + +from typing_extensions import Literal, get_args + +if TYPE_CHECKING: + from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + ) + from langchain_core.callbacks.tracers.log_stream import RunLog, RunLogPatch + from langchain_core.callbacks.tracers.root_listeners import Listener + from langchain_core.runnables.fallbacks import ( + RunnableWithFallbacks as RunnableWithFallbacksT, + ) + +from langchain_core.load.dump import dumpd +from langchain_core.load.serializable import Serializable +from langchain_core.pydantic_v1 import BaseModel, Field, create_model +from langchain_core.runnables.config import ( + RunnableConfig, + acall_func_with_variable_args, + call_func_with_variable_args, + ensure_config, + get_async_callback_manager_for_config, + get_callback_manager_for_config, + get_config_list, + get_executor_for_config, + merge_configs, + patch_config, +) +from langchain_core.runnables.utils import ( + AddableDict, + AnyConfigurableField, + ConfigurableField, + ConfigurableFieldSpec, + Input, + Output, + accepts_config, + accepts_run_manager, + gather_with_concurrency, + get_function_first_arg_dict_keys, + get_lambda_source, + get_unique_config_specs, + indent_lines_after_first, +) +from langchain_core.utils.aiter import atee, py_anext +from langchain_core.utils.iter import safetee + +Other = TypeVar("Other") + + +class Runnable(Generic[Input, Output], ABC): + """A unit of work that can be invoked, batched, streamed, transformed and composed. + + Key Methods + =========== + + * invoke/ainvoke: Transforms a single input into an output. + * batch/abatch: Efficiently transforms multiple inputs into outputs. 
+ * stream/astream: Streams output from a single input as it's produced. + * astream_log: Streams output and selected intermediate results from an input. + + Built-in optimizations: + + * Batch: By default, batch runs invoke() in parallel using a thread pool executor. + Override to optimize batching. + + * Async: Methods with "a" suffix are asynchronous. By default, they execute + the sync counterpart using asyncio's thread pool. + Override for native async. + + All methods accept an optional config argument, which can be used to configure + execution, add tags and metadata for tracing and debugging etc. + + Runnables expose schematic information about their input, output and config via + the input_schema property, the output_schema property and config_schema method. + + LCEL and Composition + ==================== + + The LangChain Expression Language (LCEL) is a declarative way to compose Runnables + into chains. Any chain constructed this way will automatically have sync, async, + batch, and streaming support. + + The main composition primitives are RunnableSequence and RunnableParallel. + + RunnableSequence invokes a series of runnables sequentially, with one runnable's + output serving as the next's input. Construct using the `|` operator or by + passing a list of runnables to RunnableSequence. + + RunnableParallel invokes runnables concurrently, providing the same input + to each. Construct it using a dict literal within a sequence or by passing a + dict to RunnableParallel. + + + For example, + + .. code-block:: python + + from langchain_core.runnables import RunnableLambda + + # A RunnableSequence constructed using the `|` operator + sequence = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2) + sequence.invoke(1) # 4 + sequence.batch([1, 2, 3]) # [4, 6, 8] + + + # A sequence that contains a RunnableParallel constructed using a dict literal + sequence = RunnableLambda(lambda x: x + 1) | { + 'mul_2': RunnableLambda(lambda x: x * 2), + 'mul_5': RunnableLambda(lambda x: x * 5) + } + sequence.invoke(1) # {'mul_2': 4, 'mul_5': 10} + + Standard Methods + ================ + + All Runnables expose additional methods that can be used to modify their behavior + (e.g., add a retry policy, add lifecycle listeners, make them configurable, etc.). + + These methods will work on any Runnable, including Runnable chains constructed + by composing other Runnables. See the individual methods for details. + + For example, + + .. code-block:: python + + from langchain_core.runnables import RunnableLambda + + import random + + def add_one(x: int) -> int: + return x + 1 + + + def buggy_double(y: int) -> int: + '''Buggy code that will fail 70% of the time''' + if random.random() > 0.3: + print('This code failed, and will probably be retried!') + raise ValueError('Triggered buggy code') + return y * 2 + + sequence = ( + RunnableLambda(add_one) | + RunnableLambda(buggy_double).with_retry( # Retry on failure + stop_after_attempt=10, + wait_exponential_jitter=False + ) + ) + + print(sequence.input_schema.schema()) # Show inferred input schema + print(sequence.output_schema.schema()) # Show inferred output schema + print(sequence.invoke(2)) # invoke the sequence (note the retry above!!) + + Debugging and tracing + ===================== + + As the chains get longer, it can be useful to be able to see intermediate results + to debug and trace the chain. + + You can set the global debug flag to True to enable debug output for all chains: + + .. 
code-block:: python + + from langchain_core.globals import set_debug + set_debug(True) + + Alternatively, you can pass existing or custom callbacks to any given chain: + + ... code-block:: python + + from langchain_core.callbacks.tracers import ConsoleCallbackHandler + + chain.invoke( + ..., + config={'callbacks': [ConsoleCallbackHandler()]} + ) + + For a UI (and much more) checkout LangSmith: https://docs.smith.langchain.com/ + """ + + @property + def InputType(self) -> Type[Input]: + """The type of input this runnable accepts specified as a type annotation.""" + for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] + type_args = get_args(cls) + if type_args and len(type_args) == 2: + return type_args[0] + + raise TypeError( + f"Runnable {self.__class__.__name__} doesn't have an inferable InputType. " + "Override the InputType property to specify the input type." + ) + + @property + def OutputType(self) -> Type[Output]: + """The type of output this runnable produces specified as a type annotation.""" + for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] + type_args = get_args(cls) + if type_args and len(type_args) == 2: + return type_args[1] + + raise TypeError( + f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. " + "Override the OutputType property to specify the output type." + ) + + @property + def input_schema(self) -> Type[BaseModel]: + """The type of input this runnable accepts specified as a pydantic model.""" + return self.get_input_schema() + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + """Get a pydantic model that can be used to validate input to the runnable. + + Runnables that leverage the configurable_fields and configurable_alternatives + methods will have a dynamic input schema that depends on which + configuration the runnable is invoked with. + + This method allows to get an input schema for a specific configuration. + + Args: + config: A config to use when generating the schema. + + Returns: + A pydantic model that can be used to validate input. + """ + root_type = self.InputType + + if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + return root_type + + return create_model( + self.__class__.__name__ + "Input", __root__=(root_type, None) + ) + + @property + def output_schema(self) -> Type[BaseModel]: + """The type of output this runnable produces specified as a pydantic model.""" + return self.get_output_schema() + + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + """Get a pydantic model that can be used to validate output to the runnable. + + Runnables that leverage the configurable_fields and configurable_alternatives + methods will have a dynamic output schema that depends on which + configuration the runnable is invoked with. + + This method allows to get an output schema for a specific configuration. + + Args: + config: A config to use when generating the schema. + + Returns: + A pydantic model that can be used to validate output. 
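A minimal sketch of the input_schema / output_schema properties defined above, assuming a type-annotated function wrapped in RunnableLambda so the types can be inferred:

.. code-block:: python

    from langchain_core.runnables import RunnableLambda

    def add_one(x: int) -> int:
        return x + 1

    runnable = RunnableLambda(add_one)
    # Each returns a pydantic model; .schema() prints a JSON-schema-like dict
    # describing the inferred int input and output.
    print(runnable.input_schema.schema())
    print(runnable.output_schema.schema())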
+ """ + root_type = self.OutputType + + if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + return root_type + + return create_model( + self.__class__.__name__ + "Output", __root__=(root_type, None) + ) + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + """List configurable fields for this runnable.""" + return [] + + def config_schema( + self, *, include: Optional[Sequence[str]] = None + ) -> Type[BaseModel]: + """The type of config this runnable accepts specified as a pydantic model. + + To mark a field as configurable, see the `configurable_fields` + and `configurable_alternatives` methods. + + Args: + include: A list of fields to include in the config schema. + + Returns: + A pydantic model that can be used to validate config. + """ + + class _Config: + arbitrary_types_allowed = True + + include = include or [] + config_specs = self.config_specs + configurable = ( + create_model( # type: ignore[call-overload] + "Configurable", + **{ + spec.id: ( + spec.annotation, + Field( + spec.default, title=spec.name, description=spec.description + ), + ) + for spec in config_specs + }, + ) + if config_specs + else None + ) + + return create_model( # type: ignore[call-overload] + self.__class__.__name__ + "Config", + __config__=_Config, + **({"configurable": (configurable, None)} if configurable else {}), + **{ + field_name: (field_type, None) + for field_name, field_type in RunnableConfig.__annotations__.items() + if field_name in [i for i in include if i != "configurable"] + }, + ) + + def __or__( + self, + other: Union[ + Runnable[Any, Other], + Callable[[Any], Other], + Callable[[Iterator[Any]], Iterator[Other]], + Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], + ], + ) -> RunnableSerializable[Input, Other]: + """Compose this runnable with another object to create a RunnableSequence.""" + return RunnableSequence(first=self, last=coerce_to_runnable(other)) + + def __ror__( + self, + other: Union[ + Runnable[Other, Any], + Callable[[Other], Any], + Callable[[Iterator[Other]], Iterator[Any]], + Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], + ], + ) -> RunnableSerializable[Other, Output]: + """Compose this runnable with another object to create a RunnableSequence.""" + return RunnableSequence(first=coerce_to_runnable(other), last=self) + + """ --- Public API --- """ + + @abstractmethod + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + """Transform a single input into an output. Override to implement. + + Args: + input: The input to the runnable. + config: A config to use when invoking the runnable. + The config supports standard keys like 'tags', 'metadata' for tracing + purposes, 'max_concurrency' for controlling how much work to do + in parallel, and other keys. Please refer to the RunnableConfig + for more details. + + Returns: + The output of the runnable. + """ + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + """Default implementation of ainvoke, calls invoke from a thread. + + The default implementation allows usage of async code even if + the runnable did not implement a native async version of invoke. + + Subclasses should override this method if they can run asynchronously. 
+ """ + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.invoke, **kwargs), input, config + ) + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + """Default implementation runs invoke in parallel using a thread pool executor. + + The default implementation of batch works well for IO bound runnables. + + Subclasses should override this method if they can batch more efficiently; + e.g., if the underlying runnable uses an API which supports a batch mode. + """ + if not inputs: + return [] + + configs = get_config_list(config, len(inputs)) + + def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]: + if return_exceptions: + try: + return self.invoke(input, config, **kwargs) + except Exception as e: + return e + else: + return self.invoke(input, config, **kwargs) + + # If there's only one input, don't bother with the executor + if len(inputs) == 1: + return cast(List[Output], [invoke(inputs[0], configs[0])]) + + with get_executor_for_config(configs[0]) as executor: + return cast(List[Output], list(executor.map(invoke, inputs, configs))) + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + """Default implementation runs ainvoke in parallel using asyncio.gather. + + The default implementation of batch works well for IO bound runnables. + + Subclasses should override this method if they can batch more efficiently; + e.g., if the underlying runnable uses an API which supports a batch mode. + """ + if not inputs: + return [] + + configs = get_config_list(config, len(inputs)) + + async def ainvoke( + input: Input, config: RunnableConfig + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return await self.ainvoke(input, config, **kwargs) + except Exception as e: + return e + else: + return await self.ainvoke(input, config, **kwargs) + + coros = map(ainvoke, inputs, configs) + return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + """ + Default implementation of stream, which calls invoke. + Subclasses should override this method if they support streaming output. + """ + yield self.invoke(input, config, **kwargs) + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + """ + Default implementation of astream, which calls ainvoke. + Subclasses should override this method if they support streaming output. + """ + yield await self.ainvoke(input, config, **kwargs) + + @overload + def astream_log( + self, + input: Any, + config: Optional[RunnableConfig] = None, + *, + diff: Literal[True] = True, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[RunLogPatch]: + ... 
+ + @overload + def astream_log( + self, + input: Any, + config: Optional[RunnableConfig] = None, + *, + diff: Literal[False], + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[RunLog]: + ... + + async def astream_log( + self, + input: Any, + config: Optional[RunnableConfig] = None, + *, + diff: bool = True, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Optional[Any], + ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]: + """ + Stream all output from a runnable, as reported to the callback system. + This includes all inner runs of LLMs, Retrievers, Tools, etc. + + Output is streamed as Log objects, which include a list of + jsonpatch ops that describe how the state of the run has changed in each + step, and the final state of the run. + + The jsonpatch ops can be applied in order to construct state. + """ + + from langchain_core.callbacks.base import BaseCallbackManager + from langchain_core.callbacks.tracers.log_stream import ( + LogStreamCallbackHandler, + RunLog, + RunLogPatch, + ) + + # Create a stream handler that will emit Log objects + stream = LogStreamCallbackHandler( + auto_close=False, + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + ) + + # Assign the stream handler to the config + config = config or {} + callbacks = config.get("callbacks") + if callbacks is None: + config["callbacks"] = [stream] + elif isinstance(callbacks, list): + config["callbacks"] = callbacks + [stream] + elif isinstance(callbacks, BaseCallbackManager): + callbacks = callbacks.copy() + callbacks.add_handler(stream, inherit=True) + config["callbacks"] = callbacks + else: + raise ValueError( + f"Unexpected type for callbacks: {callbacks}." + "Expected None, list or AsyncCallbackManager." + ) + + # Call the runnable in streaming mode, + # add each chunk to the output stream + async def consume_astream() -> None: + try: + async for chunk in self.astream(input, config, **kwargs): + await stream.send_stream.send( + RunLogPatch( + { + "op": "add", + "path": "/streamed_output/-", + "value": chunk, + } + ) + ) + finally: + await stream.send_stream.aclose() + + # Start the runnable in a task, so we can start consuming output + task = asyncio.create_task(consume_astream()) + + try: + # Yield each chunk from the output stream + if diff: + async for log in stream: + yield log + else: + state = RunLog(state=None) # type: ignore[arg-type] + async for log in stream: + state = state + log + yield state + finally: + # Wait for the runnable to finish, if not cancelled (eg. by break) + try: + await task + except asyncio.CancelledError: + pass + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + """ + Default implementation of transform, which buffers input and then calls stream. 
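A small sketch of the default buffering behavior of transform described here, assuming a runnable with no native streaming support:

.. code-block:: python

    from langchain_core.runnables import RunnableLambda

    upper = RunnableLambda(lambda s: s.upper())
    # transform() concatenates the input chunks with `+` ("hel" + "lo"),
    # then delegates to stream(), which by default yields a single invoke() result.
    print(list(upper.transform(iter(["hel", "lo"]))))  # ['HELLO']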
+ Subclasses should override this method if they can start producing output while + input is still being generated. + """ + final: Input + got_first_val = False + + for chunk in input: + if not got_first_val: + final = chunk + got_first_val = True + else: + # Make a best effort to gather, for any type that supports `+` + # This method should throw an error if gathering fails. + final = final + chunk # type: ignore[operator] + + if got_first_val: + yield from self.stream(final, config, **kwargs) + + async def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + """ + Default implementation of atransform, which buffers input and calls astream. + Subclasses should override this method if they can start producing output while + input is still being generated. + """ + final: Input + got_first_val = False + + async for chunk in input: + if not got_first_val: + final = chunk + got_first_val = True + else: + # Make a best effort to gather, for any type that supports `+` + # This method should throw an error if gathering fails. + final = final + chunk # type: ignore[operator] + + if got_first_val: + async for output in self.astream(final, config, **kwargs): + yield output + + def bind(self, **kwargs: Any) -> Runnable[Input, Output]: + """ + Bind arguments to a Runnable, returning a new Runnable. + """ + return RunnableBinding(bound=self, kwargs=kwargs, config={}) + + def with_config( + self, + config: Optional[RunnableConfig] = None, + # Sadly Unpack is not well supported by mypy so this will have to be untyped + **kwargs: Any, + ) -> Runnable[Input, Output]: + """ + Bind config to a Runnable, returning a new Runnable. + """ + return RunnableBinding( + bound=self, + config=cast( + RunnableConfig, + {**(config or {}), **kwargs}, + ), # type: ignore[misc] + kwargs={}, + ) + + def with_listeners( + self, + *, + on_start: Optional[Listener] = None, + on_end: Optional[Listener] = None, + on_error: Optional[Listener] = None, + ) -> Runnable[Input, Output]: + """ + Bind lifecycle listeners to a Runnable, returning a new Runnable. + + on_start: Called before the runnable starts running, with the Run object. + on_end: Called after the runnable finishes running, with the Run object. + on_error: Called if the runnable throws an error, with the Run object. + + The Run object contains information about the run, including its id, + type, input, output, error, start_time, end_time, and any tags or metadata + added to the run. + """ + from langchain_core.callbacks.tracers.root_listeners import RootListenersTracer + + return RunnableBinding( + bound=self, + config_factories=[ + lambda config: { + "callbacks": [ + RootListenersTracer( + config=config, + on_start=on_start, + on_end=on_end, + on_error=on_error, + ) + ], + } + ], + ) + + def with_types( + self, + *, + input_type: Optional[Type[Input]] = None, + output_type: Optional[Type[Output]] = None, + ) -> Runnable[Input, Output]: + """ + Bind input and output types to a Runnable, returning a new Runnable. + """ + return RunnableBinding( + bound=self, + custom_input_type=input_type, + custom_output_type=output_type, + kwargs={}, + ) + + def with_retry( + self, + *, + retry_if_exception_type: Tuple[Type[BaseException], ...] = (Exception,), + wait_exponential_jitter: bool = True, + stop_after_attempt: int = 3, + ) -> Runnable[Input, Output]: + """Create a new Runnable that retries the original runnable on exceptions. 
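A hedged illustration of this retry behavior (the Args documentation continues below); the flaky function and counter are invented for the example:

.. code-block:: python

    from langchain_core.runnables import RunnableLambda

    attempts = {"count": 0}

    def flaky(x: int) -> int:
        attempts["count"] += 1
        if attempts["count"] < 3:
            raise ValueError("transient failure")
        return x + 1

    resilient = RunnableLambda(flaky).with_retry(
        stop_after_attempt=3,
        wait_exponential_jitter=False,  # keep the example fast; no backoff sleeps
    )
    print(resilient.invoke(1))   # 2, after two failed attempts were retried
    print(attempts["count"])     # 3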
+ + Args: + retry_if_exception_type: A tuple of exception types to retry on + wait_exponential_jitter: Whether to add jitter to the wait time + between retries + stop_after_attempt: The maximum number of attempts to make before giving up + + Returns: + A new Runnable that retries the original runnable on exceptions. + """ + from langchain_core.runnables.retry import RunnableRetry + + return RunnableRetry( + bound=self, + kwargs={}, + config={}, + retry_exception_types=retry_if_exception_type, + wait_exponential_jitter=wait_exponential_jitter, + max_attempt_number=stop_after_attempt, + ) + + def map(self) -> Runnable[List[Input], List[Output]]: + """ + Return a new Runnable that maps a list of inputs to a list of outputs, + by calling invoke() with each input. + """ + return RunnableEach(bound=self) + + def with_fallbacks( + self, + fallbacks: Sequence[Runnable[Input, Output]], + *, + exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,), + ) -> RunnableWithFallbacksT[Input, Output]: + """Add fallbacks to a runnable, returning a new Runnable. + + Args: + fallbacks: A sequence of runnables to try if the original runnable fails. + exceptions_to_handle: A tuple of exception types to handle. + + Returns: + A new Runnable that will try the original runnable, and then each + fallback in order, upon failures. + """ + from langchain_core.runnables.fallbacks import RunnableWithFallbacks + + return RunnableWithFallbacks( + runnable=self, + fallbacks=fallbacks, + exceptions_to_handle=exceptions_to_handle, + ) + + """ --- Helper methods for Subclasses --- """ + + def _call_with_config( + self, + func: Union[ + Callable[[Input], Output], + Callable[[Input, CallbackManagerForChainRun], Output], + Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], + ], + input: Input, + config: Optional[RunnableConfig], + run_type: Optional[str] = None, + **kwargs: Optional[Any], + ) -> Output: + """Helper method to transform an Input value to an Output value, + with callbacks. Use this method to implement invoke() in subclasses.""" + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + input, + run_type=run_type, + name=config.get("run_name"), + ) + try: + output = call_func_with_variable_args( + func, input, config, run_manager, **kwargs + ) + except BaseException as e: + run_manager.on_chain_error(e) + raise + else: + run_manager.on_chain_end(dumpd(output)) + return output + + async def _acall_with_config( + self, + func: Union[ + Callable[[Input], Awaitable[Output]], + Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], + Callable[ + [Input, AsyncCallbackManagerForChainRun, RunnableConfig], + Awaitable[Output], + ], + ], + input: Input, + config: Optional[RunnableConfig], + run_type: Optional[str] = None, + **kwargs: Optional[Any], + ) -> Output: + """Helper method to transform an Input value to an Output value, + with callbacks. 
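A minimal sketch of the with_fallbacks method defined above, with an intentionally failing runnable invented for illustration:

.. code-block:: python

    from langchain_core.runnables import RunnableLambda

    def always_fails(x: int) -> int:
        raise ValueError("boom")

    safe = RunnableLambda(always_fails).with_fallbacks(
        [RunnableLambda(lambda x: x + 1)]
    )
    # The original runnable raises, so the fallback produces the result.
    print(safe.invoke(1))  # 2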
Use this method to implement ainvoke() in subclasses.""" + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) + run_manager = await callback_manager.on_chain_start( + dumpd(self), + input, + run_type=run_type, + name=config.get("run_name"), + ) + try: + output = await acall_func_with_variable_args( + func, input, config, run_manager, **kwargs + ) + except BaseException as e: + await run_manager.on_chain_error(e) + raise + else: + await run_manager.on_chain_end(dumpd(output)) + return output + + def _batch_with_config( + self, + func: Union[ + Callable[[List[Input]], List[Union[Exception, Output]]], + Callable[ + [List[Input], List[CallbackManagerForChainRun]], + List[Union[Exception, Output]], + ], + Callable[ + [List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]], + List[Union[Exception, Output]], + ], + ], + input: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + run_type: Optional[str] = None, + **kwargs: Optional[Any], + ) -> List[Output]: + """Helper method to transform an Input value to an Output value, + with callbacks. Use this method to implement invoke() in subclasses.""" + if not input: + return [] + + configs = get_config_list(config, len(input)) + callback_managers = [get_callback_manager_for_config(c) for c in configs] + run_managers = [ + callback_manager.on_chain_start( + dumpd(self), + input, + run_type=run_type, + name=config.get("run_name"), + ) + for callback_manager, input, config in zip( + callback_managers, input, configs + ) + ] + try: + if accepts_config(func): + kwargs["config"] = [ + patch_config(c, callbacks=rm.get_child()) + for c, rm in zip(configs, run_managers) + ] + if accepts_run_manager(func): + kwargs["run_manager"] = run_managers + output = func(input, **kwargs) # type: ignore[call-arg] + except BaseException as e: + for run_manager in run_managers: + run_manager.on_chain_error(e) + if return_exceptions: + return cast(List[Output], [e for _ in input]) + else: + raise + else: + first_exception: Optional[Exception] = None + for run_manager, out in zip(run_managers, output): + if isinstance(out, Exception): + first_exception = first_exception or out + run_manager.on_chain_error(out) + else: + run_manager.on_chain_end(dumpd(out)) + if return_exceptions or first_exception is None: + return cast(List[Output], output) + else: + raise first_exception + + async def _abatch_with_config( + self, + func: Union[ + Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]], + Callable[ + [List[Input], List[AsyncCallbackManagerForChainRun]], + Awaitable[List[Union[Exception, Output]]], + ], + Callable[ + [ + List[Input], + List[AsyncCallbackManagerForChainRun], + List[RunnableConfig], + ], + Awaitable[List[Union[Exception, Output]]], + ], + ], + input: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + run_type: Optional[str] = None, + **kwargs: Optional[Any], + ) -> List[Output]: + """Helper method to transform an Input value to an Output value, + with callbacks. 
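A short sketch of the return_exceptions handling used by the batch helpers above, assuming a division runnable invented for illustration:

.. code-block:: python

    from langchain_core.runnables import RunnableLambda

    safe_div = RunnableLambda(lambda x: 10 / x)
    # With return_exceptions=True, a failing input is returned as the exception
    # object instead of aborting the whole batch.
    results = safe_div.batch([1, 0, 5], return_exceptions=True)
    print(results)  # [10.0, ZeroDivisionError('division by zero'), 2.0]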
Use this method to implement invoke() in subclasses.""" + if not input: + return [] + + configs = get_config_list(config, len(input)) + callback_managers = [get_async_callback_manager_for_config(c) for c in configs] + run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( + *( + callback_manager.on_chain_start( + dumpd(self), + input, + run_type=run_type, + name=config.get("run_name"), + ) + for callback_manager, input, config in zip( + callback_managers, input, configs + ) + ) + ) + try: + if accepts_config(func): + kwargs["config"] = [ + patch_config(c, callbacks=rm.get_child()) + for c, rm in zip(configs, run_managers) + ] + if accepts_run_manager(func): + kwargs["run_manager"] = run_managers + output = await func(input, **kwargs) # type: ignore[call-arg] + except BaseException as e: + await asyncio.gather( + *(run_manager.on_chain_error(e) for run_manager in run_managers) + ) + if return_exceptions: + return cast(List[Output], [e for _ in input]) + else: + raise + else: + first_exception: Optional[Exception] = None + coros: List[Awaitable[None]] = [] + for run_manager, out in zip(run_managers, output): + if isinstance(out, Exception): + first_exception = first_exception or out + coros.append(run_manager.on_chain_error(out)) + else: + coros.append(run_manager.on_chain_end(dumpd(out))) + await asyncio.gather(*coros) + if return_exceptions or first_exception is None: + return cast(List[Output], output) + else: + raise first_exception + + def _transform_stream_with_config( + self, + input: Iterator[Input], + transformer: Union[ + Callable[[Iterator[Input]], Iterator[Output]], + Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]], + Callable[ + [ + Iterator[Input], + CallbackManagerForChainRun, + RunnableConfig, + ], + Iterator[Output], + ], + ], + config: Optional[RunnableConfig], + run_type: Optional[str] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + """Helper method to transform an Iterator of Input values into an Iterator of + Output values, with callbacks. 
+ Use this to implement `stream()` or `transform()` in Runnable subclasses.""" + # tee the input so we can iterate over it twice + input_for_tracing, input_for_transform = tee(input, 2) + # Start the input iterator to ensure the input runnable starts before this one + final_input: Optional[Input] = next(input_for_tracing, None) + final_input_supported = True + final_output: Optional[Output] = None + final_output_supported = True + + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + {"input": ""}, + run_type=run_type, + name=config.get("run_name"), + ) + try: + if accepts_config(transformer): + kwargs["config"] = patch_config( + config, callbacks=run_manager.get_child() + ) + if accepts_run_manager(transformer): + kwargs["run_manager"] = run_manager + iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg] + for chunk in iterator: + yield chunk + if final_output_supported: + if final_output is None: + final_output = chunk + else: + try: + final_output = final_output + chunk # type: ignore + except TypeError: + final_output = None + final_output_supported = False + for ichunk in input_for_tracing: + if final_input_supported: + if final_input is None: + final_input = ichunk + else: + try: + final_input = final_input + ichunk # type: ignore + except TypeError: + final_input = None + final_input_supported = False + except BaseException as e: + run_manager.on_chain_error(e, inputs=final_input) + raise + else: + run_manager.on_chain_end(final_output, inputs=final_input) + + async def _atransform_stream_with_config( + self, + input: AsyncIterator[Input], + transformer: Union[ + Callable[[AsyncIterator[Input]], AsyncIterator[Output]], + Callable[ + [AsyncIterator[Input], AsyncCallbackManagerForChainRun], + AsyncIterator[Output], + ], + Callable[ + [ + AsyncIterator[Input], + AsyncCallbackManagerForChainRun, + RunnableConfig, + ], + AsyncIterator[Output], + ], + ], + config: Optional[RunnableConfig], + run_type: Optional[str] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + """Helper method to transform an Async Iterator of Input values into an Async + Iterator of Output values, with callbacks. 
+ Use this to implement `astream()` or `atransform()` in Runnable subclasses.""" + # tee the input so we can iterate over it twice + input_for_tracing, input_for_transform = atee(input, 2) + # Start the input iterator to ensure the input runnable starts before this one + final_input: Optional[Input] = await py_anext(input_for_tracing, None) + final_input_supported = True + final_output: Optional[Output] = None + final_output_supported = True + + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) + run_manager = await callback_manager.on_chain_start( + dumpd(self), + {"input": ""}, + run_type=run_type, + name=config.get("run_name"), + ) + try: + if accepts_config(transformer): + kwargs["config"] = patch_config( + config, callbacks=run_manager.get_child() + ) + if accepts_run_manager(transformer): + kwargs["run_manager"] = run_manager + iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg] + async for chunk in iterator: + yield chunk + if final_output_supported: + if final_output is None: + final_output = chunk + else: + try: + final_output = final_output + chunk # type: ignore + except TypeError: + final_output = None + final_output_supported = False + async for ichunk in input_for_tracing: + if final_input_supported: + if final_input is None: + final_input = ichunk + else: + try: + final_input = final_input + ichunk # type: ignore[operator] + except TypeError: + final_input = None + final_input_supported = False + except BaseException as e: + await run_manager.on_chain_error(e, inputs=final_input) + raise + else: + await run_manager.on_chain_end(final_output, inputs=final_input) + + +class RunnableSerializable(Serializable, Runnable[Input, Output]): + """A Runnable that can be serialized to JSON.""" + + def configurable_fields( + self, **kwargs: AnyConfigurableField + ) -> RunnableSerializable[Input, Output]: + from langchain_core.runnables.configurable import RunnableConfigurableFields + + for key in kwargs: + if key not in self.__fields__: + raise ValueError( + f"Configuration key {key} not found in {self}: " + "available keys are {self.__fields__.keys()}" + ) + + return RunnableConfigurableFields(default=self, fields=kwargs) + + def configurable_alternatives( + self, + which: ConfigurableField, + default_key: str = "default", + **kwargs: Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]], + ) -> RunnableSerializable[Input, Output]: + from langchain_core.runnables.configurable import ( + RunnableConfigurableAlternatives, + ) + + return RunnableConfigurableAlternatives( + which=which, default=self, alternatives=kwargs, default_key=default_key + ) + + +class RunnableSequence(RunnableSerializable[Input, Output]): + """A sequence of runnables, where the output of each is the input of the next. + + RunnableSequence is the most important composition operator in LangChain as it is + used in virtually every chain. + + A RunnableSequence can be instantiated directly or more commonly by using the `|` + operator where either the left or right operands (or both) must be a Runnable. + + Any RunnableSequence automatically supports sync, async, batch. + + The default implementations of `batch` and `abatch` utilize threadpools and + asyncio gather and will be faster than naive invocation of invoke or ainvoke + for IO bound runnables. + + Batching is implemented by invoking the batch method on each component of the + RunnableSequence in order. 
+
+    A RunnableSequence preserves the streaming properties of its components, so if all
+    components of the sequence implement a `transform` method -- which
+    is the method that implements the logic to map a streaming input to a streaming
+    output -- then the sequence will be able to stream input to output!
+
+    If any component of the sequence does not implement transform then the
+    streaming will only begin after this component is run. If there are
+    multiple blocking components, streaming begins after the last one.
+
+    Please note: RunnableLambdas do not support `transform` by default! So if
+        you need to use a RunnableLambda be careful about where you place it in a
+        RunnableSequence (if you need to use the .stream()/.astream() methods).
+
+    If you need arbitrary logic and need streaming, you can subclass
+    Runnable, and implement `transform` for whatever logic you need.
+
+    Here is a simple example that uses two functions to illustrate the use of
+    RunnableSequence:
+
+        .. code-block:: python
+
+            from langchain_core.runnables import RunnableLambda
+
+            def add_one(x: int) -> int:
+                return x + 1
+
+            def mul_two(x: int) -> int:
+                return x * 2
+
+            runnable_1 = RunnableLambda(add_one)
+            runnable_2 = RunnableLambda(mul_two)
+            sequence = runnable_1 | runnable_2
+            # Or equivalently:
+            # sequence = RunnableSequence(first=runnable_1, last=runnable_2)
+            sequence.invoke(1)
+            await sequence.ainvoke(1)
+
+            sequence.batch([1, 2, 3])
+            await sequence.abatch([1, 2, 3])
+
+    Here's an example that streams JSON output generated by an LLM:
+
+        .. code-block:: python
+
+            from langchain_core.output_parsers.json import SimpleJsonOutputParser
+            from langchain_core.chat_models.openai import ChatOpenAI
+            from langchain_core.prompts import PromptTemplate
+
+            prompt = PromptTemplate.from_template(
+                'In JSON format, give me a list of {topic} and their '
+                'corresponding names in French, Spanish and in a '
+                'Cat Language.'
+            )
+
+            model = ChatOpenAI()
+            chain = prompt | model | SimpleJsonOutputParser()
+
+            async for chunk in chain.astream({'topic': 'colors'}):
+                print('-')
+                print(chunk, sep='', flush=True)
+    """
+
+    # The steps are broken into first, middle and last, solely for type checking
+    # purposes. It allows specifying the `Input` on the first type, the `Output` of
+    # the last type.
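A small illustration of how the `first`/`middle`/`last` fields below are populated by the `__or__` composition defined further down; composing with an existing sequence extends it rather than nesting it (the lambdas are arbitrary):

.. code-block:: python

    from langchain_core.runnables import RunnableLambda

    add_one = RunnableLambda(lambda x: x + 1)
    mul_two = RunnableLambda(lambda x: x * 2)
    sub_three = RunnableLambda(lambda x: x - 3)

    seq = add_one | mul_two | sub_three  # a RunnableSequence, flattened

    assert seq.first == add_one
    assert seq.middle == [mul_two]
    assert seq.last == sub_three
    assert seq.steps == [add_one, mul_two, sub_three]
    assert seq.invoke(5) == 9  # ((5 + 1) * 2) - 3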
+ first: Runnable[Input, Any] + """The first runnable in the sequence.""" + middle: List[Runnable[Any, Any]] = Field(default_factory=list) + """The middle runnables in the sequence.""" + last: Runnable[Any, Output] + """The last runnable in the sequence.""" + + @property + def steps(self) -> List[Runnable[Any, Any]]: + """All the runnables that make up the sequence in order.""" + return [self.first] + self.middle + [self.last] + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + class Config: + arbitrary_types_allowed = True + + @property + def InputType(self) -> Type[Input]: + return self.first.InputType + + @property + def OutputType(self) -> Type[Output]: + return self.last.OutputType + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + from langchain_core.runnables.passthrough import RunnableAssign + + if isinstance(self.first, RunnableAssign): + first = cast(RunnableAssign, self.first) + next_ = self.middle[0] if self.middle else self.last + next_input_schema = next_.get_input_schema(config) + if not next_input_schema.__custom_root_type__: + # it's a dict as expected + return create_model( # type: ignore[call-overload] + "RunnableSequenceInput", + **{ + k: (v.annotation, v.default) + for k, v in next_input_schema.__fields__.items() + if k not in first.mapper.steps + }, + ) + + return self.first.get_input_schema(config) + + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self.last.get_output_schema(config) + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return get_unique_config_specs( + spec for step in self.steps for spec in step.config_specs + ) + + def __repr__(self) -> str: + return "\n| ".join( + repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ") + for i, s in enumerate(self.steps) + ) + + def __or__( + self, + other: Union[ + Runnable[Any, Other], + Callable[[Any], Other], + Callable[[Iterator[Any]], Iterator[Other]], + Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], + ], + ) -> RunnableSerializable[Input, Other]: + if isinstance(other, RunnableSequence): + return RunnableSequence( + first=self.first, + middle=self.middle + [self.last] + [other.first] + other.middle, + last=other.last, + ) + else: + return RunnableSequence( + first=self.first, + middle=self.middle + [self.last], + last=coerce_to_runnable(other), + ) + + def __ror__( + self, + other: Union[ + Runnable[Other, Any], + Callable[[Other], Any], + Callable[[Iterator[Other]], Iterator[Any]], + Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], + ], + ) -> RunnableSerializable[Other, Output]: + if isinstance(other, RunnableSequence): + return RunnableSequence( + first=other.first, + middle=other.middle + [other.last] + [self.first] + self.middle, + last=self.last, + ) + else: + return RunnableSequence( + first=coerce_to_runnable(other), + middle=[self.first] + self.middle, + last=self.last, + ) + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + # setup callbacks + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + # start the root run + run_manager = callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) + + # invoke all steps in sequence + try: + for i, step in enumerate(self.steps): + input = step.invoke( + 
input, + # mark each step as a child run + patch_config( + config, callbacks=run_manager.get_child(f"seq:step:{i+1}") + ), + ) + # finish the root run + except BaseException as e: + run_manager.on_chain_error(e) + raise + else: + run_manager.on_chain_end(input) + return cast(Output, input) + + async def ainvoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: + # setup callbacks + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) + # start the root run + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) + + # invoke all steps in sequence + try: + for i, step in enumerate(self.steps): + input = await step.ainvoke( + input, + # mark each step as a child run + patch_config( + config, callbacks=run_manager.get_child(f"seq:step:{i+1}") + ), + ) + # finish the root run + except BaseException as e: + await run_manager.on_chain_error(e) + raise + else: + await run_manager.on_chain_end(input) + return cast(Output, input) + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + from langchain_core.callbacks.manager import CallbackManager + + if not inputs: + return [] + + # setup callbacks + configs = get_config_list(config, len(inputs)) + callback_managers = [ + CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + for config in configs + ] + # start the root runs, one per input + run_managers = [ + cm.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) + ] + + # invoke + try: + if return_exceptions: + # Track which inputs (by index) failed so far + # If an input has failed it will be present in this map, + # and the value will be the exception that was raised. + failed_inputs_map: Dict[int, Exception] = {} + for stepidx, step in enumerate(self.steps): + # Assemble the original indexes of the remaining inputs + # (i.e. 
the ones that haven't failed yet) + remaining_idxs = [ + i for i in range(len(configs)) if i not in failed_inputs_map + ] + # Invoke the step on the remaining inputs + inputs = step.batch( + [ + inp + for i, inp in zip(remaining_idxs, inputs) + if i not in failed_inputs_map + ], + [ + # each step a child run of the corresponding root run + patch_config( + config, callbacks=rm.get_child(f"seq:step:{stepidx+1}") + ) + for i, (rm, config) in enumerate(zip(run_managers, configs)) + if i not in failed_inputs_map + ], + return_exceptions=return_exceptions, + **kwargs, + ) + # If an input failed, add it to the map + for i, inp in zip(remaining_idxs, inputs): + if isinstance(inp, Exception): + failed_inputs_map[i] = inp + inputs = [inp for inp in inputs if not isinstance(inp, Exception)] + # If all inputs have failed, stop processing + if len(failed_inputs_map) == len(configs): + break + + # Reassemble the outputs, inserting Exceptions for failed inputs + inputs_copy = inputs.copy() + inputs = [] + for i in range(len(configs)): + if i in failed_inputs_map: + inputs.append(cast(Input, failed_inputs_map[i])) + else: + inputs.append(inputs_copy.pop(0)) + else: + for i, step in enumerate(self.steps): + inputs = step.batch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config( + config, callbacks=rm.get_child(f"seq:step:{i+1}") + ) + for rm, config in zip(run_managers, configs) + ], + ) + + # finish the root runs + except BaseException as e: + for rm in run_managers: + rm.on_chain_error(e) + if return_exceptions: + return cast(List[Output], [e for _ in inputs]) + else: + raise + else: + first_exception: Optional[Exception] = None + for run_manager, out in zip(run_managers, inputs): + if isinstance(out, Exception): + first_exception = first_exception or out + run_manager.on_chain_error(out) + else: + run_manager.on_chain_end(dumpd(out)) + if return_exceptions or first_exception is None: + return cast(List[Output], inputs) + else: + raise first_exception + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + from langchain_core.callbacks.manager import ( + AsyncCallbackManager, + ) + + if not inputs: + return [] + + # setup callbacks + configs = get_config_list(config, len(inputs)) + callback_managers = [ + AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + for config in configs + ] + # start the root runs, one per input + run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( + *( + cm.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) + ) + ) + + # invoke .batch() on each step + # this uses batching optimizations in Runnable subclasses, like LLM + try: + if return_exceptions: + # Track which inputs (by index) failed so far + # If an input has failed it will be present in this map, + # and the value will be the exception that was raised. + failed_inputs_map: Dict[int, Exception] = {} + for stepidx, step in enumerate(self.steps): + # Assemble the original indexes of the remaining inputs + # (i.e. 
the ones that haven't failed yet) + remaining_idxs = [ + i for i in range(len(configs)) if i not in failed_inputs_map + ] + # Invoke the step on the remaining inputs + inputs = await step.abatch( + [ + inp + for i, inp in zip(remaining_idxs, inputs) + if i not in failed_inputs_map + ], + [ + # each step a child run of the corresponding root run + patch_config( + config, callbacks=rm.get_child(f"seq:step:{stepidx+1}") + ) + for i, (rm, config) in enumerate(zip(run_managers, configs)) + if i not in failed_inputs_map + ], + return_exceptions=return_exceptions, + **kwargs, + ) + # If an input failed, add it to the map + for i, inp in zip(remaining_idxs, inputs): + if isinstance(inp, Exception): + failed_inputs_map[i] = inp + inputs = [inp for inp in inputs if not isinstance(inp, Exception)] + # If all inputs have failed, stop processing + if len(failed_inputs_map) == len(configs): + break + + # Reassemble the outputs, inserting Exceptions for failed inputs + inputs_copy = inputs.copy() + inputs = [] + for i in range(len(configs)): + if i in failed_inputs_map: + inputs.append(cast(Input, failed_inputs_map[i])) + else: + inputs.append(inputs_copy.pop(0)) + else: + for i, step in enumerate(self.steps): + inputs = await step.abatch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config( + config, callbacks=rm.get_child(f"seq:step:{i+1}") + ) + for rm, config in zip(run_managers, configs) + ], + ) + # finish the root runs + except BaseException as e: + await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) + if return_exceptions: + return cast(List[Output], [e for _ in inputs]) + else: + raise + else: + first_exception: Optional[Exception] = None + coros: List[Awaitable[None]] = [] + for run_manager, out in zip(run_managers, inputs): + if isinstance(out, Exception): + first_exception = first_exception or out + coros.append(run_manager.on_chain_error(out)) + else: + coros.append(run_manager.on_chain_end(dumpd(out))) + await asyncio.gather(*coros) + if return_exceptions or first_exception is None: + return cast(List[Output], inputs) + else: + raise first_exception + + def _transform( + self, + input: Iterator[Input], + run_manager: CallbackManagerForChainRun, + config: RunnableConfig, + ) -> Iterator[Output]: + steps = [self.first] + self.middle + [self.last] + + # transform the input stream of each step with the next + # steps that don't natively support transforming an input stream will + # buffer input in memory until all available, and then start emitting output + final_pipeline = cast(Iterator[Output], input) + for step in steps: + final_pipeline = step.transform( + final_pipeline, + patch_config( + config, + callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"), + ), + ) + + for output in final_pipeline: + yield output + + async def _atransform( + self, + input: AsyncIterator[Input], + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + ) -> AsyncIterator[Output]: + steps = [self.first] + self.middle + [self.last] + + # stream the last steps + # transform the input stream of each step with the next + # steps that don't natively support transforming an input stream will + # buffer input in memory until all available, and then start emitting output + final_pipeline = cast(AsyncIterator[Output], input) + for step in steps: + final_pipeline = step.atransform( + final_pipeline, + patch_config( + config, + callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"), + ), + ) + async for output in final_pipeline: + 
yield output + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + yield from self._transform_stream_with_config( + input, self._transform, config, **kwargs + ) + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + yield from self.transform(iter([input]), config, **kwargs) + + async def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + async for chunk in self._atransform_stream_with_config( + input, self._atransform, config, **kwargs + ): + yield chunk + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + async def input_aiter() -> AsyncIterator[Input]: + yield input + + async for chunk in self.atransform(input_aiter(), config, **kwargs): + yield chunk + + +class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): + """ + A runnable that runs a mapping of runnables in parallel, + and returns a mapping of their outputs. + """ + + steps: Mapping[str, Runnable[Input, Any]] + + def __init__( + self, + __steps: Optional[ + Mapping[ + str, + Union[ + Runnable[Input, Any], + Callable[[Input], Any], + Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]], + ], + ] + ] = None, + **kwargs: Union[ + Runnable[Input, Any], + Callable[[Input], Any], + Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]], + ], + ) -> None: + merged = {**__steps} if __steps is not None else {} + merged.update(kwargs) + super().__init__( + steps={key: coerce_to_runnable(r) for key, r in merged.items()} + ) + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + class Config: + arbitrary_types_allowed = True + + @property + def InputType(self) -> Any: + for step in self.steps.values(): + if step.InputType: + return step.InputType + + return Any + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + if all( + s.get_input_schema(config).schema().get("type", "object") == "object" + for s in self.steps.values() + ): + # This is correct, but pydantic typings/mypy don't think so. + return create_model( # type: ignore[call-overload] + "RunnableParallelInput", + **{ + k: (v.annotation, v.default) + for step in self.steps.values() + for k, v in step.get_input_schema(config).__fields__.items() + if k != "__root__" + }, + ) + + return super().get_input_schema(config) + + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + # This is correct, but pydantic typings/mypy don't think so. 
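The class defined here is also what plain dict literals are coerced into when they appear in a composition; a minimal usage sketch with illustrative lambdas:

.. code-block:: python

    from langchain_core.runnables.base import RunnableLambda, RunnableParallel

    joke = RunnableLambda(lambda topic: f"a joke about {topic}")
    poem = RunnableLambda(lambda topic: f"a poem about {topic}")

    both = RunnableParallel(joke=joke, poem=poem)
    # Equivalent spellings: RunnableParallel({"joke": joke, "poem": poem}),
    # or the bare dict {"joke": joke, "poem": poem} when used inside a sequence.
    assert both.invoke("cats") == {
        "joke": "a joke about cats",
        "poem": "a poem about cats",
    }

Each step receives the same input and runs on the shared thread pool (or via `asyncio.gather` in `ainvoke`), and the result is a dict keyed by step name.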
+ return create_model( # type: ignore[call-overload] + "RunnableParallelOutput", + **{k: (v.OutputType, None) for k, v in self.steps.items()}, + ) + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return get_unique_config_specs( + spec for step in self.steps.values() for spec in step.config_specs + ) + + def __repr__(self) -> str: + map_for_repr = ",\n ".join( + f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}" + for k, v in self.steps.items() + ) + return "{\n " + map_for_repr + "\n}" + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Dict[str, Any]: + from langchain_core.callbacks.manager import CallbackManager + + # setup callbacks + config = ensure_config(config) + callback_manager = CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + # start the root run + run_manager = callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) + + # gather results from all steps + try: + # copy to avoid issues from the caller mutating the steps during invoke() + steps = dict(self.steps) + with get_executor_for_config(config) as executor: + futures = [ + executor.submit( + step.invoke, + input, + # mark each step as a child run + patch_config( + config, + callbacks=run_manager.get_child(f"map:key:{key}"), + ), + ) + for key, step in steps.items() + ] + output = {key: future.result() for key, future in zip(steps, futures)} + # finish the root run + except BaseException as e: + run_manager.on_chain_error(e) + raise + else: + run_manager.on_chain_end(output) + return output + + async def ainvoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Dict[str, Any]: + # setup callbacks + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) + # start the root run + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) + + # gather results from all steps + try: + # copy to avoid issues from the caller mutating the steps during invoke() + steps = dict(self.steps) + results = await asyncio.gather( + *( + step.ainvoke( + input, + # mark each step as a child run + patch_config( + config, callbacks=run_manager.get_child(f"map:key:{key}") + ), + ) + for key, step in steps.items() + ) + ) + output = {key: value for key, value in zip(steps, results)} + # finish the root run + except BaseException as e: + await run_manager.on_chain_error(e) + raise + else: + await run_manager.on_chain_end(output) + return output + + def _transform( + self, + input: Iterator[Input], + run_manager: CallbackManagerForChainRun, + config: RunnableConfig, + ) -> Iterator[AddableDict]: + # Shallow copy steps to ignore mutations while in progress + steps = dict(self.steps) + # Each step gets a copy of the input iterator, + # which is consumed in parallel in a separate thread. 
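+        # safetee (from runnables.utils) behaves like itertools.tee but guards the
+        # shared source iterator with the given lock, so the worker threads below
+        # cannot advance it concurrently; each step pops and owns one copy.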
+ input_copies = list(safetee(input, len(steps), lock=threading.Lock())) + with get_executor_for_config(config) as executor: + # Create the transform() generator for each step + named_generators = [ + ( + name, + step.transform( + input_copies.pop(), + patch_config( + config, callbacks=run_manager.get_child(f"map:key:{name}") + ), + ), + ) + for name, step in steps.items() + ] + # Start the first iteration of each generator + futures = { + executor.submit(next, generator): (step_name, generator) + for step_name, generator in named_generators + } + # Yield chunks from each as they become available, + # and start the next iteration of that generator that yielded it. + # When all generators are exhausted, stop. + while futures: + completed_futures, _ = wait(futures, return_when=FIRST_COMPLETED) + for future in completed_futures: + (step_name, generator) = futures.pop(future) + try: + chunk = AddableDict({step_name: future.result()}) + yield chunk + futures[executor.submit(next, generator)] = ( + step_name, + generator, + ) + except StopIteration: + pass + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Dict[str, Any]]: + yield from self._transform_stream_with_config( + input, self._transform, config, **kwargs + ) + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Dict[str, Any]]: + yield from self.transform(iter([input]), config) + + async def _atransform( + self, + input: AsyncIterator[Input], + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + ) -> AsyncIterator[AddableDict]: + # Shallow copy steps to ignore mutations while in progress + steps = dict(self.steps) + # Each step gets a copy of the input iterator, + # which is consumed in parallel in a separate thread. + input_copies = list(atee(input, len(steps), lock=asyncio.Lock())) + # Create the transform() generator for each step + named_generators = [ + ( + name, + step.atransform( + input_copies.pop(), + patch_config( + config, callbacks=run_manager.get_child(f"map:key:{name}") + ), + ), + ) + for name, step in steps.items() + ] + + # Wrap in a coroutine to satisfy linter + async def get_next_chunk(generator: AsyncIterator) -> Optional[Output]: + return await py_anext(generator) + + # Start the first iteration of each generator + tasks = { + asyncio.create_task(get_next_chunk(generator)): (step_name, generator) + for step_name, generator in named_generators + } + # Yield chunks from each as they become available, + # and start the next iteration of the generator that yielded it. + # When all generators are exhausted, stop. 
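+        # An exhausted generator raises StopAsyncIteration from task.result();
+        # it is simply not rescheduled, so the task map shrinks until the loop ends.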
+ while tasks: + completed_tasks, _ = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in completed_tasks: + (step_name, generator) = tasks.pop(task) + try: + chunk = AddableDict({step_name: task.result()}) + yield chunk + new_task = asyncio.create_task(get_next_chunk(generator)) + tasks[new_task] = (step_name, generator) + except StopAsyncIteration: + pass + + async def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + async for chunk in self._atransform_stream_with_config( + input, self._atransform, config, **kwargs + ): + yield chunk + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Dict[str, Any]]: + async def input_aiter() -> AsyncIterator[Input]: + yield input + + async for chunk in self.atransform(input_aiter(), config): + yield chunk + + +# We support both names +RunnableMap = RunnableParallel + + +class RunnableGenerator(Runnable[Input, Output]): + """ + A runnable that runs a generator function. + """ + + def __init__( + self, + transform: Union[ + Callable[[Iterator[Input]], Iterator[Output]], + Callable[[AsyncIterator[Input]], AsyncIterator[Output]], + ], + atransform: Optional[ + Callable[[AsyncIterator[Input]], AsyncIterator[Output]] + ] = None, + ) -> None: + if atransform is not None: + self._atransform = atransform + + if inspect.isasyncgenfunction(transform): + self._atransform = transform + elif inspect.isgeneratorfunction(transform): + self._transform = transform + else: + raise TypeError( + "Expected a generator function type for `transform`." + f"Instead got an unsupported type: {type(transform)}" + ) + + @property + def InputType(self) -> Any: + func = getattr(self, "_transform", None) or getattr(self, "_atransform") + try: + params = inspect.signature(func).parameters + first_param = next(iter(params.values()), None) + if first_param and first_param.annotation != inspect.Parameter.empty: + return getattr(first_param.annotation, "__args__", (Any,))[0] + else: + return Any + except ValueError: + return Any + + @property + def OutputType(self) -> Any: + func = getattr(self, "_transform", None) or getattr(self, "_atransform") + try: + sig = inspect.signature(func) + return ( + getattr(sig.return_annotation, "__args__", (Any,))[0] + if sig.return_annotation != inspect.Signature.empty + else Any + ) + except ValueError: + return Any + + def __eq__(self, other: Any) -> bool: + if isinstance(other, RunnableGenerator): + if hasattr(self, "_transform") and hasattr(other, "_transform"): + return self._transform == other._transform + elif hasattr(self, "_atransform") and hasattr(other, "_atransform"): + return self._atransform == other._atransform + else: + return False + else: + return False + + def __repr__(self) -> str: + return "RunnableGenerator(...)" + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Output]: + return self._transform_stream_with_config( + input, self._transform, config, **kwargs + ) + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Output]: + return self.transform(iter([input]), config, **kwargs) + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + final = None + for output in self.stream(input, config, **kwargs): + if final is None: + final 
= output + else: + final = final + output + return cast(Output, final) + + def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Output]: + if not hasattr(self, "_atransform"): + raise NotImplementedError("This runnable does not support async methods.") + + return self._atransform_stream_with_config( + input, self._atransform, config, **kwargs + ) + + def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Output]: + async def input_aiter() -> AsyncIterator[Input]: + yield input + + return self.atransform(input_aiter(), config, **kwargs) + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + final = None + async for output in self.astream(input, config, **kwargs): + if final is None: + final = output + else: + final = final + output + return cast(Output, final) + + +class RunnableLambda(Runnable[Input, Output]): + """RunnableLambda converts a python callable into a Runnable. + + Wrapping a callable in a RunnableLambda makes the callable usable + within either a sync or async context. + + RunnableLambda can be composed as any other Runnable and provides + seamless integration with LangChain tracing. + + Examples: + + .. code-block:: python + + # This is a RunnableLambda + from langchain_core.runnables import RunnableLambda + + def add_one(x: int) -> int: + return x + 1 + + runnable = RunnableLambda(add_one) + + runnable.invoke(1) # returns 2 + runnable.batch([1, 2, 3]) # returns [2, 3, 4] + + # Async is supported by default by delegating to the sync implementation + await runnable.ainvoke(1) # returns 2 + await runnable.abatch([1, 2, 3]) # returns [2, 3, 4] + + + # Alternatively, can provide both synd and sync implementations + async def add_one_async(x: int) -> int: + return x + 1 + + runnable = RunnableLambda(add_one, afunc=add_one_async) + runnable.invoke(1) # Uses add_one + await runnable.ainvoke(1) # Uses add_one_async + """ + + def __init__( + self, + func: Union[ + Union[ + Callable[[Input], Output], + Callable[[Input, RunnableConfig], Output], + Callable[[Input, CallbackManagerForChainRun], Output], + Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], + ], + Union[ + Callable[[Input], Awaitable[Output]], + Callable[[Input, RunnableConfig], Awaitable[Output]], + Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], + Callable[ + [Input, AsyncCallbackManagerForChainRun, RunnableConfig], + Awaitable[Output], + ], + ], + ], + afunc: Optional[ + Union[ + Callable[[Input], Awaitable[Output]], + Callable[[Input, RunnableConfig], Awaitable[Output]], + Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], + Callable[ + [Input, AsyncCallbackManagerForChainRun, RunnableConfig], + Awaitable[Output], + ], + ] + ] = None, + ) -> None: + """Create a RunnableLambda from a callable, and async callable or both. + + Accepts both sync and async variants to allow providing efficient + implementations for sync and async execution. + + Args: + func: Either sync or async callable + afunc: An async callable that takes an input and returns an output. + """ + if afunc is not None: + self.afunc = afunc + + if inspect.iscoroutinefunction(func): + if afunc is not None: + raise TypeError( + "Func was provided as a coroutine function, but afunc was " + "also provided. If providing both, func should be a regular " + "function to avoid ambiguity." 
+ ) + self.afunc = func + elif callable(func): + self.func = cast(Callable[[Input], Output], func) + else: + raise TypeError( + "Expected a callable type for `func`." + f"Instead got an unsupported type: {type(func)}" + ) + + @property + def InputType(self) -> Any: + """The type of the input to this runnable.""" + func = getattr(self, "func", None) or getattr(self, "afunc") + try: + params = inspect.signature(func).parameters + first_param = next(iter(params.values()), None) + if first_param and first_param.annotation != inspect.Parameter.empty: + return first_param.annotation + else: + return Any + except ValueError: + return Any + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + """The pydantic schema for the input to this runnable.""" + func = getattr(self, "func", None) or getattr(self, "afunc") + + if isinstance(func, itemgetter): + # This is terrible, but afaict it's not possible to access _items + # on itemgetter objects, so we have to parse the repr + items = str(func).replace("operator.itemgetter(", "")[:-1].split(", ") + if all( + item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items + ): + # It's a dict, lol + return create_model( + "RunnableLambdaInput", + **{item[1:-1]: (Any, None) for item in items}, # type: ignore + ) + else: + return create_model("RunnableLambdaInput", __root__=(List[Any], None)) + + if self.InputType != Any: + return super().get_input_schema(config) + + if dict_keys := get_function_first_arg_dict_keys(func): + return create_model( + "RunnableLambdaInput", + **{key: (Any, None) for key in dict_keys}, # type: ignore + ) + + return super().get_input_schema(config) + + @property + def OutputType(self) -> Any: + """The type of the output of this runnable as a type annotation.""" + func = getattr(self, "func", None) or getattr(self, "afunc") + try: + sig = inspect.signature(func) + return ( + sig.return_annotation + if sig.return_annotation != inspect.Signature.empty + else Any + ) + except ValueError: + return Any + + def __eq__(self, other: Any) -> bool: + if isinstance(other, RunnableLambda): + if hasattr(self, "func") and hasattr(other, "func"): + return self.func == other.func + elif hasattr(self, "afunc") and hasattr(other, "afunc"): + return self.afunc == other.afunc + else: + return False + else: + return False + + def __repr__(self) -> str: + """A string representation of this runnable.""" + if hasattr(self, "func"): + return f"RunnableLambda({get_lambda_source(self.func) or '...'})" + elif hasattr(self, "afunc"): + return f"RunnableLambda(afunc={get_lambda_source(self.afunc) or '...'})" + else: + return "RunnableLambda(...)" + + def _invoke( + self, + input: Input, + run_manager: CallbackManagerForChainRun, + config: RunnableConfig, + **kwargs: Any, + ) -> Output: + output = call_func_with_variable_args( + self.func, input, config, run_manager, **kwargs + ) + # If the output is a runnable, invoke it + if isinstance(output, Runnable): + recursion_limit = config["recursion_limit"] + if recursion_limit <= 0: + raise RecursionError( + f"Recursion limit reached when invoking {self} with input {input}." 
+ ) + output = output.invoke( + input, + patch_config( + config, + callbacks=run_manager.get_child(), + recursion_limit=recursion_limit - 1, + ), + ) + return output + + async def _ainvoke( + self, + input: Input, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + **kwargs: Any, + ) -> Output: + output = await acall_func_with_variable_args( + self.afunc, input, config, run_manager, **kwargs + ) + # If the output is a runnable, invoke it + if isinstance(output, Runnable): + recursion_limit = config["recursion_limit"] + if recursion_limit <= 0: + raise RecursionError( + f"Recursion limit reached when invoking {self} with input {input}." + ) + output = await output.ainvoke( + input, + patch_config( + config, + callbacks=run_manager.get_child(), + recursion_limit=recursion_limit - 1, + ), + ) + return output + + def _config( + self, config: Optional[RunnableConfig], callable: Callable[..., Any] + ) -> RunnableConfig: + config = config or {} + + if config.get("run_name") is None: + try: + run_name = callable.__name__ + except AttributeError: + run_name = None + if run_name is not None: + return patch_config(config, run_name=run_name) + + return config + + def invoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: + """Invoke this runnable synchronously.""" + if hasattr(self, "func"): + return self._call_with_config( + self._invoke, + input, + self._config(config, self.func), + **kwargs, + ) + else: + raise TypeError( + "Cannot invoke a coroutine function synchronously." + "Use `ainvoke` instead." + ) + + async def ainvoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: + """Invoke this runnable asynchronously.""" + if hasattr(self, "afunc"): + return await self._acall_with_config( + self._ainvoke, + input, + self._config(config, self.afunc), + **kwargs, + ) + else: + # Delegating to super implementation of ainvoke. + # Uses asyncio executor to run the sync version (invoke) + return await super().ainvoke(input, config) + + +class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): + """ + A runnable that delegates calls to another runnable + with each element of the input sequence. + + Use only if creating a new RunnableEach subclass with different __init__ args. 
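The re-invocation branch in `_invoke`/`_ainvoke` above means a lambda may return another Runnable to route dynamically, bounded by `config["recursion_limit"]`; an illustrative sketch (the routing rule is made up):

.. code-block:: python

    from langchain_core.runnables import RunnableLambda

    to_upper = RunnableLambda(lambda s: s.upper())
    to_lower = RunnableLambda(lambda s: s.lower())

    # Returning a Runnable causes it to be invoked on the same input as a child run.
    router = RunnableLambda(lambda s: to_upper if s.startswith("loud") else to_lower)

    assert router.invoke("loud noises") == "LOUD NOISES"
    assert router.invoke("Quiet Please") == "quiet please"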
+ """ + + bound: Runnable[Input, Output] + + class Config: + arbitrary_types_allowed = True + + @property + def InputType(self) -> Any: + return List[self.bound.InputType] # type: ignore[name-defined] + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return create_model( + "RunnableEachInput", + __root__=( + List[self.bound.get_input_schema(config)], # type: ignore + None, + ), + ) + + @property + def OutputType(self) -> Type[List[Output]]: + return List[self.bound.OutputType] # type: ignore[name-defined] + + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + schema = self.bound.get_output_schema(config) + return create_model( + "RunnableEachOutput", + __root__=( + List[schema], # type: ignore + None, + ), + ) + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return self.bound.config_specs + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + def _invoke( + self, + inputs: List[Input], + run_manager: CallbackManagerForChainRun, + config: RunnableConfig, + **kwargs: Any, + ) -> List[Output]: + return self.bound.batch( + inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs + ) + + def invoke( + self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> List[Output]: + return self._call_with_config(self._invoke, input, config, **kwargs) + + async def _ainvoke( + self, + inputs: List[Input], + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + **kwargs: Any, + ) -> List[Output]: + return await self.bound.abatch( + inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs + ) + + async def ainvoke( + self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> List[Output]: + return await self._acall_with_config(self._ainvoke, input, config, **kwargs) + + +class RunnableEach(RunnableEachBase[Input, Output]): + """ + A runnable that delegates calls to another runnable + with each element of the input sequence. + """ + + def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: + return RunnableEach(bound=self.bound.bind(**kwargs)) + + def with_config( + self, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> RunnableEach[Input, Output]: + return RunnableEach(bound=self.bound.with_config(config, **kwargs)) + + def with_listeners( + self, + *, + on_start: Optional[Listener] = None, + on_end: Optional[Listener] = None, + on_error: Optional[Listener] = None, + ) -> RunnableEach[Input, Output]: + """ + Bind lifecycle listeners to a Runnable, returning a new Runnable. + + on_start: Called before the runnable starts running, with the Run object. + on_end: Called after the runnable finishes running, with the Run object. + on_error: Called if the runnable throws an error, with the Run object. + + The Run object contains information about the run, including its id, + type, input, output, error, start_time, end_time, and any tags or metadata + added to the run. + """ + return RunnableEach( + bound=self.bound.with_listeners( + on_start=on_start, on_end=on_end, on_error=on_error + ) + ) + + +class RunnableBindingBase(RunnableSerializable[Input, Output]): + """ + A runnable that delegates calls to another runnable with a set of kwargs. + + Use only if creating a new RunnableBinding subclass with different __init__ args. 
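A small sketch of the `RunnableEach` wrapper defined above: the bound runnable is applied to every element of a list input by delegating to `batch`/`abatch`, as in `_invoke` (the `add_one` lambda is illustrative):

.. code-block:: python

    from langchain_core.runnables.base import RunnableEach, RunnableLambda

    add_one = RunnableLambda(lambda x: x + 1)
    each = RunnableEach(bound=add_one)

    assert each.invoke([1, 2, 3]) == [2, 3, 4]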
+ """ + + bound: Runnable[Input, Output] + + kwargs: Mapping[str, Any] = Field(default_factory=dict) + + config: RunnableConfig = Field(default_factory=dict) + + config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field( + default_factory=list + ) + + # Union[Type[Input], BaseModel] + things like List[str] + custom_input_type: Optional[Any] = None + # Union[Type[Output], BaseModel] + things like List[str] + custom_output_type: Optional[Any] = None + + class Config: + arbitrary_types_allowed = True + + def __init__( + self, + *, + bound: Runnable[Input, Output], + kwargs: Optional[Mapping[str, Any]] = None, + config: Optional[RunnableConfig] = None, + config_factories: Optional[ + List[Callable[[RunnableConfig], RunnableConfig]] + ] = None, + custom_input_type: Optional[Union[Type[Input], BaseModel]] = None, + custom_output_type: Optional[Union[Type[Output], BaseModel]] = None, + **other_kwargs: Any, + ) -> None: + config = config or {} + # config_specs contains the list of valid `configurable` keys + if configurable := config.get("configurable", None): + allowed_keys = set(s.id for s in bound.config_specs) + for key in configurable: + if key not in allowed_keys: + raise ValueError( + f"Configurable key '{key}' not found in runnable with" + f" config keys: {allowed_keys}" + ) + super().__init__( + bound=bound, + kwargs=kwargs or {}, + config=config or {}, + config_factories=config_factories or [], + custom_input_type=custom_input_type, + custom_output_type=custom_output_type, + **other_kwargs, + ) + + @property + def InputType(self) -> Type[Input]: + return ( + cast(Type[Input], self.custom_input_type) + if self.custom_input_type is not None + else self.bound.InputType + ) + + @property + def OutputType(self) -> Type[Output]: + return ( + cast(Type[Output], self.custom_output_type) + if self.custom_output_type is not None + else self.bound.OutputType + ) + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + if self.custom_input_type is not None: + return super().get_input_schema(config) + return self.bound.get_input_schema(merge_configs(self.config, config)) + + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + if self.custom_output_type is not None: + return super().get_output_schema(config) + return self.bound.get_output_schema(merge_configs(self.config, config)) + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return self.bound.config_specs + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: + config = merge_configs(self.config, *configs) + return merge_configs(config, *(f(config) for f in self.config_factories)) + + def invoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: + return self.bound.invoke( + input, + self._merge_configs(config), + **{**self.kwargs, **kwargs}, + ) + + async def ainvoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: + return await self.bound.ainvoke( + input, + self._merge_configs(config), + **{**self.kwargs, **kwargs}, + ) + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) 
-> List[Output]: + if isinstance(config, list): + configs = cast( + List[RunnableConfig], + [self._merge_configs(conf) for conf in config], + ) + else: + configs = [self._merge_configs(config) for _ in range(len(inputs))] + return self.bound.batch( + inputs, + configs, + return_exceptions=return_exceptions, + **{**self.kwargs, **kwargs}, + ) + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + if isinstance(config, list): + configs = cast( + List[RunnableConfig], + [self._merge_configs(conf) for conf in config], + ) + else: + configs = [self._merge_configs(config) for _ in range(len(inputs))] + return await self.bound.abatch( + inputs, + configs, + return_exceptions=return_exceptions, + **{**self.kwargs, **kwargs}, + ) + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + yield from self.bound.stream( + input, + self._merge_configs(config), + **{**self.kwargs, **kwargs}, + ) + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + async for item in self.bound.astream( + input, + self._merge_configs(config), + **{**self.kwargs, **kwargs}, + ): + yield item + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Output]: + yield from self.bound.transform( + input, + self._merge_configs(config), + **{**self.kwargs, **kwargs}, + ) + + async def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Output]: + async for item in self.bound.atransform( + input, + self._merge_configs(config), + **{**self.kwargs, **kwargs}, + ): + yield item + + +RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig) + + +class RunnableBinding(RunnableBindingBase[Input, Output]): + """ + A runnable that delegates calls to another runnable with a set of kwargs. + """ + + def bind(self, **kwargs: Any) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound, + config=self.config, + kwargs={**self.kwargs, **kwargs}, + custom_input_type=self.custom_input_type, + custom_output_type=self.custom_output_type, + ) + + def with_config( + self, + config: Optional[RunnableConfig] = None, + # Sadly Unpack is not well supported by mypy so this will have to be untyped + **kwargs: Any, + ) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound, + kwargs=self.kwargs, + config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}), + custom_input_type=self.custom_input_type, + custom_output_type=self.custom_output_type, + ) + + def with_listeners( + self, + *, + on_start: Optional[Listener] = None, + on_end: Optional[Listener] = None, + on_error: Optional[Listener] = None, + ) -> Runnable[Input, Output]: + """ + Bind lifecycle listeners to a Runnable, returning a new Runnable. + + on_start: Called before the runnable starts running, with the Run object. + on_end: Called after the runnable finishes running, with the Run object. + on_error: Called if the runnable throws an error, with the Run object. + + The Run object contains information about the run, including its id, + type, input, output, error, start_time, end_time, and any tags or metadata + added to the run. 
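`Runnable.bind(...)` and `.with_config(...)` produce instances of this binding class; a sketch of the kwarg-binding behaviour with an illustrative `shout` helper, where the bound keyword is forwarded to the wrapped callable on every call:

.. code-block:: python

    from langchain_core.runnables import RunnableLambda

    def shout(text: str, *, punctuation: str = "!") -> str:
        return text.upper() + punctuation

    shouter = RunnableLambda(shout)
    question = shouter.bind(punctuation="?")  # a RunnableBinding

    assert shouter.invoke("hello") == "HELLO!"
    assert question.invoke("hello") == "HELLO?"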
+ """ + from langchain_core.callbacks.tracers.root_listeners import RootListenersTracer + + return self.__class__( + bound=self.bound, + kwargs=self.kwargs, + config=self.config, + config_factories=[ + lambda config: { + "callbacks": [ + RootListenersTracer( + config=config, + on_start=on_start, + on_end=on_end, + on_error=on_error, + ) + ], + } + ], + custom_input_type=self.custom_input_type, + custom_output_type=self.custom_output_type, + ) + + def with_types( + self, + input_type: Optional[Union[Type[Input], BaseModel]] = None, + output_type: Optional[Union[Type[Output], BaseModel]] = None, + ) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound, + kwargs=self.kwargs, + config=self.config, + custom_input_type=input_type + if input_type is not None + else self.custom_input_type, + custom_output_type=output_type + if output_type is not None + else self.custom_output_type, + ) + + def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound.with_retry(**kwargs), + kwargs=self.kwargs, + config=self.config, + ) + + +RunnableLike = Union[ + Runnable[Input, Output], + Callable[[Input], Output], + Callable[[Input], Awaitable[Output]], + Callable[[Iterator[Input]], Iterator[Output]], + Callable[[AsyncIterator[Input]], AsyncIterator[Output]], + Mapping[str, Any], +] + + +def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: + """Coerce a runnable-like object into a Runnable. + + Args: + thing: A runnable-like object. + + Returns: + A Runnable. + """ + if isinstance(thing, Runnable): + return thing + elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing): + return RunnableGenerator(thing) + elif callable(thing): + return RunnableLambda(cast(Callable[[Input], Output], thing)) + elif isinstance(thing, dict): + return cast(Runnable[Input, Output], RunnableParallel(thing)) + else: + raise TypeError( + f"Expected a Runnable, callable or dict." + f"Instead got an unsupported type: {type(thing)}" + ) diff --git a/libs/core/langchain_core/runnables/branch.py b/libs/core/langchain_core/runnables/branch.py new file mode 100644 index 00000000000..96e9685dc6d --- /dev/null +++ b/libs/core/langchain_core/runnables/branch.py @@ -0,0 +1,254 @@ +from typing import ( + Any, + Awaitable, + Callable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +from langchain_core.load.dump import dumpd +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables.base import ( + Runnable, + RunnableLike, + RunnableSerializable, + coerce_to_runnable, +) +from langchain_core.runnables.config import ( + RunnableConfig, + ensure_config, + get_callback_manager_for_config, + patch_config, +) +from langchain_core.runnables.utils import ( + ConfigurableFieldSpec, + Input, + Output, + get_unique_config_specs, +) + + +class RunnableBranch(RunnableSerializable[Input, Output]): + """A Runnable that selects which branch to run based on a condition. + + The runnable is initialized with a list of (condition, runnable) pairs and + a default branch. + + When operating on an input, the first condition that evaluates to True is + selected, and the corresponding runnable is run on the input. + + If no condition evaluates to True, the default branch is run on the input. + + Examples: + + .. 
code-block:: python + + from langchain_core.runnables import RunnableBranch + + branch = RunnableBranch( + (lambda x: isinstance(x, str), lambda x: x.upper()), + (lambda x: isinstance(x, int), lambda x: x + 1), + (lambda x: isinstance(x, float), lambda x: x * 2), + lambda x: "goodbye", + ) + + branch.invoke("hello") # "HELLO" + branch.invoke(None) # "goodbye" + """ + + branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]] + default: Runnable[Input, Output] + + def __init__( + self, + *branches: Union[ + Tuple[ + Union[ + Runnable[Input, bool], + Callable[[Input], bool], + Callable[[Input], Awaitable[bool]], + ], + RunnableLike, + ], + RunnableLike, # To accommodate the default branch + ], + ) -> None: + """A Runnable that runs one of two branches based on a condition.""" + if len(branches) < 2: + raise ValueError("RunnableBranch requires at least two branches") + + default = branches[-1] + + if not isinstance( + default, + (Runnable, Callable, Mapping), # type: ignore[arg-type] + ): + raise TypeError( + "RunnableBranch default must be runnable, callable or mapping." + ) + + default_ = cast( + Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default)) + ) + + _branches = [] + + for branch in branches[:-1]: + if not isinstance(branch, (tuple, list)): # type: ignore[arg-type] + raise TypeError( + f"RunnableBranch branches must be " + f"tuples or lists, not {type(branch)}" + ) + + if not len(branch) == 2: + raise ValueError( + f"RunnableBranch branches must be " + f"tuples or lists of length 2, not {len(branch)}" + ) + condition, runnable = branch + condition = cast(Runnable[Input, bool], coerce_to_runnable(condition)) + runnable = coerce_to_runnable(runnable) + _branches.append((condition, runnable)) + + super().__init__(branches=_branches, default=default_) + + class Config: + arbitrary_types_allowed = True + + @classmethod + def is_lc_serializable(cls) -> bool: + """RunnableBranch is serializable if all its branches are serializable.""" + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """The namespace of a RunnableBranch is the namespace of its default branch.""" + return cls.__module__.split(".")[:-1] + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + runnables = ( + [self.default] + + [r for _, r in self.branches] + + [r for r, _ in self.branches] + ) + + for runnable in runnables: + if runnable.get_input_schema(config).schema().get("type") is not None: + return runnable.get_input_schema(config) + + return super().get_input_schema(config) + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return get_unique_config_specs( + spec + for step in ( + [self.default] + + [r for _, r in self.branches] + + [r for r, _ in self.branches] + ) + for spec in step.config_specs + ) + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + """First evaluates the condition, then delegate to true or false branch.""" + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + + try: + for idx, branch in enumerate(self.branches): + condition, runnable = branch + + expression_value = condition.invoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), + ), + ) + + if expression_value: + output = runnable.invoke( + input, + 
config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), + ), + **kwargs, + ) + break + else: + output = self.default.invoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag="branch:default") + ), + **kwargs, + ) + except Exception as e: + run_manager.on_chain_error(e) + raise + run_manager.on_chain_end(dumpd(output)) + return output + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + """Async version of invoke.""" + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + try: + for idx, branch in enumerate(self.branches): + condition, runnable = branch + + expression_value = await condition.ainvoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), + ), + ) + + if expression_value: + output = await runnable.ainvoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), + ), + **kwargs, + ) + break + else: + output = await self.default.ainvoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag="branch:default") + ), + **kwargs, + ) + except Exception as e: + run_manager.on_chain_error(e) + raise + run_manager.on_chain_end(dumpd(output)) + return output diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py new file mode 100644 index 00000000000..e68b7080f1d --- /dev/null +++ b/libs/core/langchain_core/runnables/config.py @@ -0,0 +1,401 @@ +from __future__ import annotations + +from concurrent.futures import Executor, ThreadPoolExecutor +from contextlib import contextmanager +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Generator, + List, + Optional, + Union, + cast, +) + +from typing_extensions import TypedDict + +from langchain_core.runnables.utils import ( + Input, + Output, + accepts_config, + accepts_run_manager, +) + +if TYPE_CHECKING: + from langchain_core.callbacks.base import BaseCallbackManager, Callbacks + from langchain_core.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForChainRun, + CallbackManager, + CallbackManagerForChainRun, + ) +else: + # Pydantic validates through typed dicts, but + # the callbacks need forward refs updated + Callbacks = Optional[Union[List, Any]] + + +class EmptyDict(TypedDict, total=False): + """Empty dict type.""" + + pass + + +class RunnableConfig(TypedDict, total=False): + """Configuration for a Runnable.""" + + tags: List[str] + """ + Tags for this call and any sub-calls (eg. a Chain calling an LLM). + You can use these to filter calls. + """ + + metadata: Dict[str, Any] + """ + Metadata for this call and any sub-calls (eg. a Chain calling an LLM). + Keys should be strings, values should be JSON-serializable. + """ + + callbacks: Callbacks + """ + Callbacks for this call and any sub-calls (eg. a Chain calling an LLM). + Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. + """ + + run_name: str + """ + Name for the tracer run for this call. Defaults to the name of the class. + """ + + max_concurrency: Optional[int] + """ + Maximum number of parallel calls to make. If not provided, defaults to + ThreadPoolExecutor's default. + """ + + recursion_limit: int + """ + Maximum number of times a call can recurse. 
If not provided, defaults to 25. + """ + + configurable: Dict[str, Any] + """ + Runtime values for attributes previously made configurable on this Runnable, + or sub-Runnables, through .configurable_fields() or .configurable_alternatives(). + Check .output_schema() for a description of the attributes that have been made + configurable. + """ + + +def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: + """Ensure that a config is a dict with all keys present. + + Args: + config (Optional[RunnableConfig], optional): The config to ensure. + Defaults to None. + + Returns: + RunnableConfig: The ensured config. + """ + empty = RunnableConfig( + tags=[], + metadata={}, + callbacks=None, + recursion_limit=25, + ) + if config is not None: + empty.update( + cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}) + ) + return empty + + +def get_config_list( + config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int +) -> List[RunnableConfig]: + """Get a list of configs from a single config or a list of configs. + + It is useful for subclasses overriding batch() or abatch(). + + Args: + config (Optional[Union[RunnableConfig, List[RunnableConfig]]]): + The config or list of configs. + length (int): The length of the list. + + Returns: + List[RunnableConfig]: The list of configs. + + Raises: + ValueError: If the length of the list is not equal to the length of the inputs. + + """ + if length < 0: + raise ValueError(f"length must be >= 0, but got {length}") + if isinstance(config, list) and len(config) != length: + raise ValueError( + f"config must be a list of the same length as inputs, " + f"but got {len(config)} configs for {length} inputs" + ) + + return ( + list(map(ensure_config, config)) + if isinstance(config, list) + else [ensure_config(config) for _ in range(length)] + ) + + +def patch_config( + config: Optional[RunnableConfig], + *, + callbacks: Optional[BaseCallbackManager] = None, + recursion_limit: Optional[int] = None, + max_concurrency: Optional[int] = None, + run_name: Optional[str] = None, + configurable: Optional[Dict[str, Any]] = None, +) -> RunnableConfig: + """Patch a config with new values. + + Args: + config (Optional[RunnableConfig]): The config to patch. + copy_locals (bool, optional): Whether to copy locals. Defaults to False. + callbacks (Optional[BaseCallbackManager], optional): The callbacks to set. + Defaults to None. + recursion_limit (Optional[int], optional): The recursion limit to set. + Defaults to None. + max_concurrency (Optional[int], optional): The max concurrency to set. + Defaults to None. + run_name (Optional[str], optional): The run name to set. Defaults to None. + configurable (Optional[Dict[str, Any]], optional): The configurable to set. + Defaults to None. + + Returns: + RunnableConfig: The patched config. 
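+
+    Example (an illustrative sketch; assumes ``run_manager`` is a
+    ``CallbackManagerForChainRun`` obtained from a parent run)::
+
+        child_config = patch_config(
+            config,
+            callbacks=run_manager.get_child(),
+            recursion_limit=10,
+        )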
+ """ + config = ensure_config(config) + if callbacks is not None: + # If we're replacing callbacks, we need to unset run_name + # As that should apply only to the same run as the original callbacks + config["callbacks"] = callbacks + if "run_name" in config: + del config["run_name"] + if recursion_limit is not None: + config["recursion_limit"] = recursion_limit + if max_concurrency is not None: + config["max_concurrency"] = max_concurrency + if run_name is not None: + config["run_name"] = run_name + if configurable is not None: + config["configurable"] = {**config.get("configurable", {}), **configurable} + return config + + +def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: + """Merge multiple configs into one. + + Args: + *configs (Optional[RunnableConfig]): The configs to merge. + + Returns: + RunnableConfig: The merged config. + """ + base: RunnableConfig = {} + # Even though the keys aren't literals, this is correct + # because both dicts are the same type + for config in (c for c in configs if c is not None): + for key in config: + if key == "metadata": + base[key] = { # type: ignore + **base.get(key, {}), # type: ignore + **(config.get(key) or {}), # type: ignore + } + elif key == "tags": + base[key] = list( # type: ignore + set(base.get(key, []) + (config.get(key) or [])), # type: ignore + ) + elif key == "configurable": + base[key] = { # type: ignore + **base.get(key, {}), # type: ignore + **(config.get(key) or {}), # type: ignore + } + elif key == "callbacks": + base_callbacks = base.get("callbacks") + these_callbacks = config["callbacks"] + # callbacks can be either None, list[handler] or manager + # so merging two callbacks values has 6 cases + if isinstance(these_callbacks, list): + if base_callbacks is None: + base["callbacks"] = these_callbacks + elif isinstance(base_callbacks, list): + base["callbacks"] = base_callbacks + these_callbacks + else: + # base_callbacks is a manager + mngr = base_callbacks.copy() + for callback in these_callbacks: + mngr.add_handler(callback, inherit=True) + base["callbacks"] = mngr + elif these_callbacks is not None: + # these_callbacks is a manager + if base_callbacks is None: + base["callbacks"] = these_callbacks + elif isinstance(base_callbacks, list): + mngr = these_callbacks.copy() + for callback in base_callbacks: + mngr.add_handler(callback, inherit=True) + base["callbacks"] = mngr + else: + # base_callbacks is also a manager + base["callbacks"] = base_callbacks.__class__( + parent_run_id=base_callbacks.parent_run_id + or these_callbacks.parent_run_id, + handlers=base_callbacks.handlers + these_callbacks.handlers, + inheritable_handlers=base_callbacks.inheritable_handlers + + these_callbacks.inheritable_handlers, + tags=list(set(base_callbacks.tags + these_callbacks.tags)), + inheritable_tags=list( + set( + base_callbacks.inheritable_tags + + these_callbacks.inheritable_tags + ) + ), + metadata={ + **base_callbacks.metadata, + **these_callbacks.metadata, + }, + ) + else: + base[key] = config[key] or base.get(key) # type: ignore + return base + + +def call_func_with_variable_args( + func: Union[ + Callable[[Input], Output], + Callable[[Input, RunnableConfig], Output], + Callable[[Input, CallbackManagerForChainRun], Output], + Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], + ], + input: Input, + config: RunnableConfig, + run_manager: Optional[CallbackManagerForChainRun] = None, + **kwargs: Any, +) -> Output: + """Call function that may optionally accept a run_manager and/or config. 
+ + Args: + func (Union[Callable[[Input], Output], + Callable[[Input, CallbackManagerForChainRun], Output], + Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output]]): + The function to call. + input (Input): The input to the function. + run_manager (CallbackManagerForChainRun): The run manager to + pass to the function. + config (RunnableConfig): The config to pass to the function. + **kwargs (Any): The keyword arguments to pass to the function. + + Returns: + Output: The output of the function. + """ + if accepts_config(func): + if run_manager is not None: + kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) + else: + kwargs["config"] = config + if run_manager is not None and accepts_run_manager(func): + kwargs["run_manager"] = run_manager + return func(input, **kwargs) # type: ignore[call-arg] + + +async def acall_func_with_variable_args( + func: Union[ + Callable[[Input], Awaitable[Output]], + Callable[[Input, RunnableConfig], Awaitable[Output]], + Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], + Callable[ + [Input, AsyncCallbackManagerForChainRun, RunnableConfig], + Awaitable[Output], + ], + ], + input: Input, + config: RunnableConfig, + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + **kwargs: Any, +) -> Output: + """Call function that may optionally accept a run_manager and/or config. + + Args: + func (Union[Callable[[Input], Awaitable[Output]], Callable[[Input, + AsyncCallbackManagerForChainRun], Awaitable[Output]], Callable[[Input, + AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]]]): + The function to call. + input (Input): The input to the function. + run_manager (AsyncCallbackManagerForChainRun): The run manager + to pass to the function. + config (RunnableConfig): The config to pass to the function. + **kwargs (Any): The keyword arguments to pass to the function. + + Returns: + Output: The output of the function. + """ + if accepts_config(func): + if run_manager is not None: + kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) + else: + kwargs["config"] = config + if run_manager is not None and accepts_run_manager(func): + kwargs["run_manager"] = run_manager + return await func(input, **kwargs) # type: ignore[call-arg] + + +def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: + """Get a callback manager for a config. + + Args: + config (RunnableConfig): The config. + + Returns: + CallbackManager: The callback manager. + """ + from langchain_core.callbacks.manager import CallbackManager + + return CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) + + +def get_async_callback_manager_for_config( + config: RunnableConfig, +) -> AsyncCallbackManager: + """Get an async callback manager for a config. + + Args: + config (RunnableConfig): The config. + + Returns: + AsyncCallbackManager: The async callback manager. + """ + from langchain_core.callbacks.manager import AsyncCallbackManager + + return AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) + + +@contextmanager +def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]: + """Get an executor for a config. + + Args: + config (RunnableConfig): The config. + + Yields: + Generator[Executor, None, None]: The executor. 
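+
+    Example (an illustrative sketch; ``step`` and ``inputs`` are hypothetical)::
+
+        with get_executor_for_config(config) as executor:
+            outputs = list(executor.map(lambda i: step.invoke(i, config), inputs))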
+ """ + with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor: + yield executor diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py new file mode 100644 index 00000000000..7d95d157030 --- /dev/null +++ b/libs/core/langchain_core/runnables/configurable.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +import enum +import threading +from abc import abstractmethod +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Optional, + Sequence, + Type, + Union, + cast, +) +from weakref import WeakValueDictionary + +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables.base import Runnable, RunnableSerializable +from langchain_core.runnables.config import ( + RunnableConfig, + get_config_list, + get_executor_for_config, +) +from langchain_core.runnables.utils import ( + AnyConfigurableField, + ConfigurableField, + ConfigurableFieldMultiOption, + ConfigurableFieldSingleOption, + ConfigurableFieldSpec, + Input, + Output, + gather_with_concurrency, + get_unique_config_specs, +) + + +class DynamicRunnable(RunnableSerializable[Input, Output]): + """A Serializable Runnable that can be dynamically configured.""" + + default: RunnableSerializable[Input, Output] + + class Config: + arbitrary_types_allowed = True + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + @property + def InputType(self) -> Type[Input]: + return self.default.InputType + + @property + def OutputType(self) -> Type[Output]: + return self.default.OutputType + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self._prepare(config).get_input_schema(config) + + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self._prepare(config).get_output_schema(config) + + @abstractmethod + def _prepare( + self, config: Optional[RunnableConfig] = None + ) -> Runnable[Input, Output]: + ... 
+ + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + return self._prepare(config).invoke(input, config, **kwargs) + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + return await self._prepare(config).ainvoke(input, config, **kwargs) + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + configs = get_config_list(config, len(inputs)) + prepared = [self._prepare(c) for c in configs] + + if all(p is self.default for p in prepared): + return self.default.batch( + inputs, config, return_exceptions=return_exceptions, **kwargs + ) + + if not inputs: + return [] + + configs = get_config_list(config, len(inputs)) + + def invoke( + bound: Runnable[Input, Output], + input: Input, + config: RunnableConfig, + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return bound.invoke(input, config, **kwargs) + except Exception as e: + return e + else: + return bound.invoke(input, config, **kwargs) + + # If there's only one input, don't bother with the executor + if len(inputs) == 1: + return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])]) + + with get_executor_for_config(configs[0]) as executor: + return cast( + List[Output], list(executor.map(invoke, prepared, inputs, configs)) + ) + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + configs = get_config_list(config, len(inputs)) + prepared = [self._prepare(c) for c in configs] + + if all(p is self.default for p in prepared): + return await self.default.abatch( + inputs, config, return_exceptions=return_exceptions, **kwargs + ) + + if not inputs: + return [] + + configs = get_config_list(config, len(inputs)) + + async def ainvoke( + bound: Runnable[Input, Output], + input: Input, + config: RunnableConfig, + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return await bound.ainvoke(input, config, **kwargs) + except Exception as e: + return e + else: + return await bound.ainvoke(input, config, **kwargs) + + coros = map(ainvoke, prepared, inputs, configs) + return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + return self._prepare(config).stream(input, config, **kwargs) + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + async for chunk in self._prepare(config).astream(input, config, **kwargs): + yield chunk + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + return self._prepare(config).transform(input, config, **kwargs) + + async def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + async for chunk in self._prepare(config).atransform(input, config, **kwargs): + yield chunk + + +class RunnableConfigurableFields(DynamicRunnable[Input, Output]): + """A Runnable that can be dynamically configured.""" + + fields: Dict[str, 
AnyConfigurableField] + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return get_unique_config_specs( + [ + ConfigurableFieldSpec( + id=spec.id, + name=spec.name, + description=spec.description + or self.default.__fields__[field_name].field_info.description, + annotation=spec.annotation + or self.default.__fields__[field_name].annotation, + default=getattr(self.default, field_name), + ) + if isinstance(spec, ConfigurableField) + else make_options_spec( + spec, self.default.__fields__[field_name].field_info.description + ) + for field_name, spec in self.fields.items() + ] + + list(self.default.config_specs) + ) + + def configurable_fields( + self, **kwargs: AnyConfigurableField + ) -> RunnableSerializable[Input, Output]: + return self.default.configurable_fields(**{**self.fields, **kwargs}) + + def _prepare( + self, config: Optional[RunnableConfig] = None + ) -> Runnable[Input, Output]: + config = config or {} + specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()} + configurable_fields = { + specs_by_id[k][0]: v + for k, v in config.get("configurable", {}).items() + if k in specs_by_id and isinstance(specs_by_id[k][1], ConfigurableField) + } + configurable_single_options = { + k: v.options[(config.get("configurable", {}).get(v.id) or v.default)] + for k, v in self.fields.items() + if isinstance(v, ConfigurableFieldSingleOption) + } + configurable_multi_options = { + k: [ + v.options[o] + for o in config.get("configurable", {}).get(v.id, v.default) + ] + for k, v in self.fields.items() + if isinstance(v, ConfigurableFieldMultiOption) + } + configurable = { + **configurable_fields, + **configurable_single_options, + **configurable_multi_options, + } + + if configurable: + return self.default.__class__(**{**self.default.__dict__, **configurable}) + else: + return self.default + + +# Before Python 3.11 native StrEnum is not available +class StrEnum(str, enum.Enum): + """A string enum.""" + + pass + + +_enums_for_spec: WeakValueDictionary[ + Union[ + ConfigurableFieldSingleOption, ConfigurableFieldMultiOption, ConfigurableField + ], + Type[StrEnum], +] = WeakValueDictionary() + +_enums_for_spec_lock = threading.Lock() + + +class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): + """A Runnable that can be dynamically configured.""" + + which: ConfigurableField + + alternatives: Dict[ + str, + Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]], + ] + + default_key: str = "default" + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + with _enums_for_spec_lock: + if which_enum := _enums_for_spec.get(self.which): + pass + else: + which_enum = StrEnum( # type: ignore[call-overload] + self.which.name or self.which.id, + ( + (v, v) + for v in list(self.alternatives.keys()) + [self.default_key] + ), + ) + _enums_for_spec[self.which] = cast(Type[StrEnum], which_enum) + return [ + ConfigurableFieldSpec( + id=self.which.id, + name=self.which.name, + description=self.which.description, + annotation=which_enum, + default=self.default_key, + ), + *self.default.config_specs, + ] + [ + s + for alt in self.alternatives.values() + if isinstance(alt, RunnableSerializable) + for s in alt.config_specs + ] + + def configurable_fields( + self, **kwargs: AnyConfigurableField + ) -> RunnableSerializable[Input, Output]: + return self.__class__( + which=self.which, + default=self.default.configurable_fields(**kwargs), + alternatives=self.alternatives, + ) + + def _prepare( + self, config: Optional[RunnableConfig] = None + 
) -> Runnable[Input, Output]: + config = config or {} + which = config.get("configurable", {}).get(self.which.id, self.default_key) + if which == self.default_key: + return self.default + elif which in self.alternatives: + alt = self.alternatives[which] + if isinstance(alt, Runnable): + return alt + else: + return alt() + else: + raise ValueError(f"Unknown alternative: {which}") + + +def make_options_spec( + spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption], + description: Optional[str], +) -> ConfigurableFieldSpec: + """Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or + ConfigurableFieldMultiOption.""" + with _enums_for_spec_lock: + if enum := _enums_for_spec.get(spec): + pass + else: + enum = StrEnum( # type: ignore[call-overload] + spec.name or spec.id, + ((v, v) for v in list(spec.options.keys())), + ) + _enums_for_spec[spec] = cast(Type[StrEnum], enum) + if isinstance(spec, ConfigurableFieldSingleOption): + return ConfigurableFieldSpec( + id=spec.id, + name=spec.name, + description=spec.description or description, + annotation=enum, + default=spec.default, + ) + else: + return ConfigurableFieldSpec( + id=spec.id, + name=spec.name, + description=spec.description or description, + annotation=Sequence[enum], # type: ignore[valid-type] + default=spec.default, + ) diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py new file mode 100644 index 00000000000..1959b100c8e --- /dev/null +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -0,0 +1,344 @@ +import asyncio +from typing import ( + TYPE_CHECKING, + Any, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from langchain_core.load.dump import dumpd +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables.base import Runnable, RunnableSerializable +from langchain_core.runnables.config import ( + RunnableConfig, + ensure_config, + get_async_callback_manager_for_config, + get_callback_manager_for_config, + get_config_list, + patch_config, +) +from langchain_core.runnables.utils import ( + ConfigurableFieldSpec, + Input, + Output, + get_unique_config_specs, +) + +if TYPE_CHECKING: + from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun + + +class RunnableWithFallbacks(RunnableSerializable[Input, Output]): + """A Runnable that can fallback to other Runnables if it fails. + + External APIs (e.g., APIs for a language model) may at times experience + degraded performance or even downtime. + + In these cases, it can be useful to have a fallback runnable that can be + used in place of the original runnable (e.g., fallback to another LLM provider). + + Fallbacks can be defined at the level of a single runnable, or at the level + of a chain of runnables. Fallbacks are tried in order until one succeeds or + all fail. + + While you can instantiate a ``RunnableWithFallbacks`` directly, it is usually + more convenient to use the ``with_fallbacks`` method on a runnable. + + Example: + + .. code-block:: python + + from langchain_core.chat_models.openai import ChatOpenAI + from langchain_core.chat_models.anthropic import ChatAnthropic + + model = ChatAnthropic().with_fallbacks([ChatOpenAI()]) + # Will usually use ChatAnthropic, but fallback to ChatOpenAI + # if ChatAnthropic fails. + model.invoke('hello') + + # And you can also use fallbacks at the level of a chain. + # Here if both LLM providers fail, we'll fallback to a good hardcoded + # response. 
+ + from langchain_core.prompts import PromptTemplate + from langchain_core.schema.output_parser import StrOutputParser + from langchain_core.runnables import RunnableLambda + + def when_all_is_lost(inputs): + return ("Looks like our LLM providers are down. " + "Here's a nice ðŸ¦œï¸ emoji for you instead.") + + chain_with_fallback = ( + PromptTemplate.from_template('Tell me a joke about {topic}') + | model + | StrOutputParser() + ).with_fallbacks([RunnableLambda(when_all_is_lost)]) + """ + + runnable: Runnable[Input, Output] + """The runnable to run first.""" + fallbacks: Sequence[Runnable[Input, Output]] + """A sequence of fallbacks to try.""" + exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,) + """The exceptions on which fallbacks should be tried. + + Any exception that is not a subclass of these exceptions will be raised immediately. + """ + + class Config: + arbitrary_types_allowed = True + + @property + def InputType(self) -> Type[Input]: + return self.runnable.InputType + + @property + def OutputType(self) -> Type[Output]: + return self.runnable.OutputType + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self.runnable.get_input_schema(config) + + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self.runnable.get_output_schema(config) + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return get_unique_config_specs( + spec + for step in [self.runnable, *self.fallbacks] + for spec in step.config_specs + ) + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + @property + def runnables(self) -> Iterator[Runnable[Input, Output]]: + yield self.runnable + yield from self.fallbacks + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + # setup callbacks + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + # start the root run + run_manager = callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) + first_error = None + for runnable in self.runnables: + try: + output = runnable.invoke( + input, + patch_config(config, callbacks=run_manager.get_child()), + **kwargs, + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e + except BaseException as e: + run_manager.on_chain_error(e) + raise e + else: + run_manager.on_chain_end(output) + return output + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + run_manager.on_chain_error(first_error) + raise first_error + + async def ainvoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: + # setup callbacks + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) + # start the root run + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) + + first_error = None + for runnable in self.runnables: + try: + output = await runnable.ainvoke( + input, + patch_config(config, callbacks=run_manager.get_child()), + **kwargs, + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e + except BaseException as e: + await run_manager.on_chain_error(e) + raise e + else: + await run_manager.on_chain_end(output) + return 
output + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + await run_manager.on_chain_error(first_error) + raise first_error + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + from langchain_core.callbacks.manager import CallbackManager + + if return_exceptions: + raise NotImplementedError() + + if not inputs: + return [] + + # setup callbacks + configs = get_config_list(config, len(inputs)) + callback_managers = [ + CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + for config in configs + ] + # start the root runs, one per input + run_managers = [ + cm.on_chain_start( + dumpd(self), + input if isinstance(input, dict) else {"input": input}, + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) + ] + + first_error = None + for runnable in self.runnables: + try: + outputs = runnable.batch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + return_exceptions=return_exceptions, + **kwargs, + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e + except BaseException as e: + for rm in run_managers: + rm.on_chain_error(e) + raise e + else: + for rm, output in zip(run_managers, outputs): + rm.on_chain_end(output) + return outputs + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + for rm in run_managers: + rm.on_chain_error(first_error) + raise first_error + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + from langchain_core.callbacks.manager import AsyncCallbackManager + + if return_exceptions: + raise NotImplementedError() + + if not inputs: + return [] + + # setup callbacks + configs = get_config_list(config, len(inputs)) + callback_managers = [ + AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + for config in configs + ] + # start the root runs, one per input + run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( + *( + cm.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) + ) + ) + + first_error = None + for runnable in self.runnables: + try: + outputs = await runnable.abatch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + return_exceptions=return_exceptions, + **kwargs, + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e + except BaseException as e: + await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) + else: + await asyncio.gather( + *( + rm.on_chain_end(output) + for rm, output in zip(run_managers, outputs) + ) + ) 
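+                # Every input succeeded with this runnable, so skip the
+                # remaining fallbacks.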
+ return outputs + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers)) + raise first_error diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py new file mode 100644 index 00000000000..d8eb5814514 --- /dev/null +++ b/libs/core/langchain_core/runnables/history.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +import asyncio +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Type, + Union, +) + +from langchain_core.load import load +from langchain_core.pydantic_v1 import BaseModel, create_model +from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda +from langchain_core.runnables.passthrough import RunnablePassthrough +from langchain_core.runnables.utils import ( + ConfigurableFieldSpec, + get_unique_config_specs, +) +from langchain_core.schema.chat_history import BaseChatMessageHistory + +if TYPE_CHECKING: + from langchain_core.callbacks.tracers.schemas import Run + from langchain_core.runnables.config import RunnableConfig + from langchain_core.schema.messages import BaseMessage + +MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]] +GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] + + +class RunnableWithMessageHistory(RunnableBindingBase): + """A runnable that manages chat message history for another runnable. + + Base runnable must have inputs and outputs that can be converted to a list of + BaseMessages. + + RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.: + ``{"configurable": {"session_id": ""}}`` + + Example (dict input): + .. code-block:: python + + from typing import Optional + + from langchain_core.chat_models import ChatAnthropic + from langchain_core.memory.chat_message_histories import RedisChatMessageHistory + from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + from langchain_core.runnables.history import RunnableWithMessageHistory + + + prompt = ChatPromptTemplate.from_messages([ + ("system", "You're an assistant who's good at {ability}"), + MessagesPlaceholder(variable_name="history"), + ("human", "{question}"), + ]) + + chain = prompt | ChatAnthropic(model="claude-2") + + chain_with_history = RunnableWithMessageHistory( + chain, + RedisChatMessageHistory, + input_messages_key="question", + history_messages_key="history", + ) + + chain_with_history.invoke( + {"ability": "math", "question": "What does cosine mean?"}, + config={"configurable": {"session_id": "foo"}} + ) + # -> "Cosine is ..." + chain_with_history.invoke( + {"ability": "math", "question": "What's its inverse"}, + config={"configurable": {"session_id": "foo"}} + ) + # -> "The inverse of cosine is called arccosine ..." + + """ # noqa: E501 + + get_session_history: GetSessionHistoryCallable + input_messages_key: Optional[str] = None + output_messages_key: Optional[str] = None + history_messages_key: Optional[str] = None + + def __init__( + self, + runnable: Runnable[ + MessagesOrDictWithMessages, + Union[str, BaseMessage, MessagesOrDictWithMessages], + ], + get_session_history: GetSessionHistoryCallable, + *, + input_messages_key: Optional[str] = None, + output_messages_key: Optional[str] = None, + history_messages_key: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize RunnableWithMessageHistory. 
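+
+        Internally the wrapped runnable is composed with a ``load_history``
+        step that injects messages from ``get_session_history`` before the
+        call, and an ``on_end`` listener that writes the new input and output
+        messages back to the history afterwards.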
+ + Args: + runnable: The base Runnable to be wrapped. + + Must take as input one of: + - A sequence of BaseMessages + - A dict with one key for all messages + - A dict with one key for the current input string/message(s) and + a separate key for historical messages. If the input key points + to a string, it will be treated as a HumanMessage in history. + + Must return as output one of: + - A string which can be treated as an AIMessage + - A BaseMessage or sequence of BaseMessages + - A dict with a key for a BaseMessage or sequence of BaseMessages + + get_session_history: Function that returns a new BaseChatMessageHistory + given a session id. Should take a single + positional argument `session_id` which is a string and a named argument + `user_id` which can be a string or None. e.g.: + + ```python + def get_session_history( + session_id: str, + *, + user_id: Optional[str]=None + ) -> BaseChatMessageHistory: + ... + ``` + + input_messages_key: Must be specified if the base runnable accepts a dict + as input. + output_messages_key: Must be specified if the base runnable returns a dict + as output. + history_messages_key: Must be specified if the base runnable accepts a dict + as input and expects a separate key for historical messages. + **kwargs: Arbitrary additional kwargs to pass to parent class + ``RunnableBindingBase`` init. + """ # noqa: E501 + history_chain: Runnable = RunnableLambda( + self._enter_history, self._aenter_history + ).with_config(run_name="load_history") + messages_key = history_messages_key or input_messages_key + if messages_key: + history_chain = RunnablePassthrough.assign( + **{messages_key: history_chain} + ).with_config(run_name="insert_history") + bound = ( + history_chain | runnable.with_listeners(on_end=self._exit_history) + ).with_config(run_name="RunnableWithMessageHistory") + super().__init__( + get_session_history=get_session_history, + input_messages_key=input_messages_key, + output_messages_key=output_messages_key, + bound=bound, + history_messages_key=history_messages_key, + **kwargs, + ) + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return get_unique_config_specs( + super().config_specs + + [ + ConfigurableFieldSpec( + id="session_id", + annotation=str, + name="Session ID", + description="Unique identifier for a session.", + default="", + ), + ] + ) + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + super_schema = super().get_input_schema(config) + if super_schema.__custom_root_type__ is not None: + from langchain_core.schema.messages import BaseMessage + + fields: Dict = {} + if self.input_messages_key and self.history_messages_key: + fields[self.input_messages_key] = ( + Union[str, BaseMessage, Sequence[BaseMessage]], + ..., + ) + elif self.input_messages_key: + fields[self.input_messages_key] = (Sequence[BaseMessage], ...) + else: + fields["__root__"] = (Sequence[BaseMessage], ...) + if self.history_messages_key: + fields[self.history_messages_key] = (Sequence[BaseMessage], ...) 
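+            # Expose the collected message fields as a dynamic pydantic
+            # input model.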
+ return create_model( # type: ignore[call-overload] + "RunnableWithChatHistoryInput", + **fields, + ) + else: + return super_schema + + def _get_input_messages( + self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]] + ) -> List[BaseMessage]: + from langchain_core.schema.messages import BaseMessage + + if isinstance(input_val, str): + from langchain_core.schema.messages import HumanMessage + + return [HumanMessage(content=input_val)] + elif isinstance(input_val, BaseMessage): + return [input_val] + elif isinstance(input_val, (list, tuple)): + return list(input_val) + else: + raise ValueError( + f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. " + f"Got {input_val}." + ) + + def _get_output_messages( + self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] + ) -> List[BaseMessage]: + from langchain_core.schema.messages import BaseMessage + + if isinstance(output_val, dict): + output_val = output_val[self.output_messages_key or "output"] + + if isinstance(output_val, str): + from langchain_core.schema.messages import AIMessage + + return [AIMessage(content=output_val)] + elif isinstance(output_val, BaseMessage): + return [output_val] + elif isinstance(output_val, (list, tuple)): + return list(output_val) + else: + raise ValueError() + + def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]: + hist = config["configurable"]["message_history"] + # return only historic messages + if self.history_messages_key: + return hist.messages.copy() + # return all messages + else: + input_val = ( + input if not self.input_messages_key else input[self.input_messages_key] + ) + return hist.messages.copy() + self._get_input_messages(input_val) + + async def _aenter_history( + self, input: Dict[str, Any], config: RunnableConfig + ) -> List[BaseMessage]: + return await asyncio.get_running_loop().run_in_executor( + None, self._enter_history, input, config + ) + + def _exit_history(self, run: Run, config: RunnableConfig) -> None: + hist = config["configurable"]["message_history"] + + # Get the input messages + inputs = load(run.inputs) + input_val = inputs[self.input_messages_key or "input"] + input_messages = self._get_input_messages(input_val) + + # Get the output messages + output_val = load(run.outputs) + output_messages = self._get_output_messages(output_val) + + for m in input_messages + output_messages: + hist.add_message(m) + + def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: + config = super()._merge_configs(*configs) + # extract session_id + if "session_id" not in config.get("configurable", {}): + example_input = {self.input_messages_key: "foo"} + example_config = {"configurable": {"session_id": "123"}} + raise ValueError( + "session_id_id is required." + " Pass it in as part of the config argument to .invoke() or .stream()" + f"\neg. 
chain.invoke({example_input}, {example_config})" + ) + # attach message_history + session_id = config["configurable"]["session_id"] + config["configurable"]["message_history"] = self.get_session_history(session_id) + return config diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py new file mode 100644 index 00000000000..37d97fb9681 --- /dev/null +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -0,0 +1,453 @@ +"""Implementation of the RunnablePassthrough.""" +from __future__ import annotations + +import asyncio +import inspect +import threading +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Type, + Union, + cast, +) + +from langchain_core.pydantic_v1 import BaseModel, create_model +from langchain_core.runnables.base import ( + Other, + Runnable, + RunnableParallel, + RunnableSerializable, +) +from langchain_core.runnables.config import ( + RunnableConfig, + acall_func_with_variable_args, + call_func_with_variable_args, + get_executor_for_config, +) +from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec +from langchain_core.utils.aiter import atee, py_anext +from langchain_core.utils.iter import safetee + + +def identity(x: Other) -> Other: + """An identity function""" + return x + + +async def aidentity(x: Other) -> Other: + """An async identity function""" + return x + + +class RunnablePassthrough(RunnableSerializable[Other, Other]): + """A runnable to passthrough inputs unchanged or with additional keys. + + This runnable behaves almost like the identity function, except that it + can be configured to add additional keys to the output, if the input is a + dict. + + The examples below demonstrate this runnable works using a few simple + chains. The chains rely on simple lambdas to make the examples easy to execute + and experiment with. + + Examples: + + .. code-block:: python + + from langchain_core.runnables import RunnablePassthrough, RunnableParallel + + runnable = RunnableParallel( + origin=RunnablePassthrough(), + modified=lambda x: x+1 + ) + + runnable.invoke(1) # {'origin': 1, 'modified': 2} + + + def fake_llm(prompt: str) -> str: # Fake LLM for the example + return "completion" + + chain = RunnableLambda(fake_llm) | { + 'original': RunnablePassthrough(), # Original LLM output + 'parsed': lambda text: text[::-1] # Parsing logic + } + + chain.invoke('hello') # {'original': 'completion', 'parsed': 'noitelpmoc'} + + In some cases, it may be useful to pass the input through while adding some + keys to the output. In this case, you can use the `assign` method: + + .. 
code-block:: python + + from langchain_core.runnables import RunnablePassthrough, RunnableParallel + + def fake_llm(prompt: str) -> str: # Fake LLM for the example + return "completion" + + runnable = { + 'llm1': fake_llm, + 'llm2': fake_llm, + } + | RunnablePassthrough.assign( + total_chars=lambda inputs: len(inputs['llm1'] + inputs['llm2']) + ) + + runnable.invoke('hello') + # {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20} + """ + + input_type: Optional[Type[Other]] = None + + func: Optional[ + Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]] + ] = None + + afunc: Optional[ + Union[ + Callable[[Other], Awaitable[None]], + Callable[[Other, RunnableConfig], Awaitable[None]], + ] + ] = None + + def __init__( + self, + func: Optional[ + Union[ + Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]], + Union[ + Callable[[Other], Awaitable[None]], + Callable[[Other, RunnableConfig], Awaitable[None]], + ], + ] + ] = None, + afunc: Optional[ + Union[ + Callable[[Other], Awaitable[None]], + Callable[[Other, RunnableConfig], Awaitable[None]], + ] + ] = None, + *, + input_type: Optional[Type[Other]] = None, + **kwargs: Any, + ) -> None: + if inspect.iscoroutinefunction(func): + afunc = func + func = None + + super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs) + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + @property + def InputType(self) -> Any: + return self.input_type or Any + + @property + def OutputType(self) -> Any: + return self.input_type or Any + + @classmethod + def assign( + cls, + **kwargs: Union[ + Runnable[Dict[str, Any], Any], + Callable[[Dict[str, Any]], Any], + Mapping[ + str, + Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]], + ], + ], + ) -> RunnableAssign: + """Merge the Dict input with the output produced by the mapping argument. + + Args: + mapping: A mapping from keys to runnables or callables. + + Returns: + A runnable that merges the Dict input with the output produced by the + mapping argument. 
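+
+        Example (illustrative)::
+
+            chain = RunnablePassthrough.assign(doubled=lambda d: d["x"] * 2)
+            chain.invoke({"x": 2})  # -> {"x": 2, "doubled": 4}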
+ """ + return RunnableAssign(RunnableParallel(kwargs)) + + def invoke( + self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Other: + if self.func is not None: + call_func_with_variable_args(self.func, input, config or {}, **kwargs) + return self._call_with_config(identity, input, config) + + async def ainvoke( + self, + input: Other, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Other: + if self.afunc is not None: + await acall_func_with_variable_args( + self.afunc, input, config or {}, **kwargs + ) + elif self.func is not None: + call_func_with_variable_args(self.func, input, config or {}, **kwargs) + return await self._acall_with_config(aidentity, input, config) + + def transform( + self, + input: Iterator[Other], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Other]: + if self.func is None: + for chunk in self._transform_stream_with_config(input, identity, config): + yield chunk + else: + final = None + + for chunk in self._transform_stream_with_config(input, identity, config): + yield chunk + if final is None: + final = chunk + else: + final = final + chunk + + if final is not None: + call_func_with_variable_args(self.func, final, config or {}, **kwargs) + + async def atransform( + self, + input: AsyncIterator[Other], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Other]: + if self.afunc is None and self.func is None: + async for chunk in self._atransform_stream_with_config( + input, identity, config + ): + yield chunk + else: + final = None + + async for chunk in self._atransform_stream_with_config( + input, identity, config + ): + yield chunk + if final is None: + final = chunk + else: + final = final + chunk + + if final is not None: + config = config or {} + if self.afunc is not None: + await acall_func_with_variable_args( + self.afunc, final, config, **kwargs + ) + elif self.func is not None: + call_func_with_variable_args(self.func, final, config, **kwargs) + + def stream( + self, + input: Other, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Other]: + return self.transform(iter([input]), config, **kwargs) + + async def astream( + self, + input: Other, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Other]: + async def input_aiter() -> AsyncIterator[Other]: + yield input + + async for chunk in self.atransform(input_aiter(), config, **kwargs): + yield chunk + + +class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): + """ + A runnable that assigns key-value pairs to Dict[str, Any] inputs. + """ + + mapper: RunnableParallel[Dict[str, Any]] + + def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None: + super().__init__(mapper=mapper, **kwargs) + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + map_input_schema = self.mapper.get_input_schema(config) + if not map_input_schema.__custom_root_type__: + # ie. 
it's a dict + return map_input_schema + + return super().get_input_schema(config) + + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + map_input_schema = self.mapper.get_input_schema(config) + map_output_schema = self.mapper.get_output_schema(config) + if ( + not map_input_schema.__custom_root_type__ + and not map_output_schema.__custom_root_type__ + ): + # ie. both are dicts + return create_model( # type: ignore[call-overload] + "RunnableAssignOutput", + **{ + k: (v.type_, v.default) + for s in (map_input_schema, map_output_schema) + for k, v in s.__fields__.items() + }, + ) + elif not map_output_schema.__custom_root_type__: + # ie. only map output is a dict + # ie. input type is either unknown or inferred incorrectly + return map_output_schema + + return super().get_output_schema(config) + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return self.mapper.config_specs + + def invoke( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + assert isinstance( + input, dict + ), "The input to RunnablePassthrough.assign() must be a dict." + return { + **input, + **self.mapper.invoke(input, config, **kwargs), + } + + async def ainvoke( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + assert isinstance( + input, dict + ), "The input to RunnablePassthrough.assign() must be a dict." + return { + **input, + **await self.mapper.ainvoke(input, config, **kwargs), + } + + def transform( + self, + input: Iterator[Dict[str, Any]], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Dict[str, Any]]: + # collect mapper keys + mapper_keys = set(self.mapper.steps.keys()) + # create two streams, one for the map and one for the passthrough + for_passthrough, for_map = safetee(input, 2, lock=threading.Lock()) + # create map output stream + map_output = self.mapper.transform(for_map, config, **kwargs) + # get executor to start map output stream in background + with get_executor_for_config(config or {}) as executor: + # start map output stream + first_map_chunk_future = executor.submit( + next, + map_output, # type: ignore + None, + ) + # consume passthrough stream + for chunk in for_passthrough: + assert isinstance( + chunk, dict + ), "The input to RunnablePassthrough.assign() must be a dict." + # remove mapper keys from passthrough chunk, to be overwritten by map + filtered = AddableDict( + {k: v for k, v in chunk.items() if k not in mapper_keys} + ) + if filtered: + yield filtered + # yield map output + yield cast(Dict[str, Any], first_map_chunk_future.result()) + for chunk in map_output: + yield chunk + + async def atransform( + self, + input: AsyncIterator[Dict[str, Any]], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + # collect mapper keys + mapper_keys = set(self.mapper.steps.keys()) + # create two streams, one for the map and one for the passthrough + for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock()) + # create map output stream + map_output = self.mapper.atransform(for_map, config, **kwargs) + # start map output stream + first_map_chunk_task: asyncio.Task = asyncio.create_task( + py_anext(map_output, None), # type: ignore[arg-type] + ) + # consume passthrough stream + async for chunk in for_passthrough: + assert isinstance( + chunk, dict + ), "The input to RunnablePassthrough.assign() must be a dict." 
+ # remove mapper keys from passthrough chunk, to be overwritten by map output + filtered = AddableDict( + {k: v for k, v in chunk.items() if k not in mapper_keys} + ) + if filtered: + yield filtered + # yield map output + yield await first_map_chunk_task + async for chunk in map_output: + yield chunk + + def stream( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Dict[str, Any]]: + return self.transform(iter([input]), config, **kwargs) + + async def astream( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + async def input_aiter() -> AsyncIterator[Dict[str, Any]]: + yield input + + async for chunk in self.atransform(input_aiter(), config, **kwargs): + yield chunk diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py new file mode 100644 index 00000000000..7aeb974648d --- /dev/null +++ b/libs/core/langchain_core/runnables/retry.py @@ -0,0 +1,337 @@ +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from tenacity import ( + AsyncRetrying, + RetryCallState, + RetryError, + Retrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) + +from langchain_core.runnables.base import Input, Output, RunnableBindingBase +from langchain_core.runnables.config import RunnableConfig, patch_config + +if TYPE_CHECKING: + from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + ) + + T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun) +U = TypeVar("U") + + +class RunnableRetry(RunnableBindingBase[Input, Output]): + """Retry a Runnable if it fails. + + A RunnableRetry helps can be used to add retry logic to any object + that subclasses the base Runnable. + + Such retries are especially useful for network calls that may fail + due to transient errors. + + The RunnableRetry is implemented as a RunnableBinding. The easiest + way to use it is through the `.with_retry()` method on all Runnables. + + Example: + + Here's an example that uses a RunnableLambda to raise an exception + + .. code-block:: python + + import time + + def foo(input) -> None: + '''Fake function that raises an exception.''' + raise ValueError("Invoking foo failed. At time {time.time()}") + + runnable = RunnableLambda(foo) + + runnable_with_retries = runnable.with_retry( + retry_exception_types=(ValueError,), # Retry only on ValueError + wait_exponential_jitter=True, # Add jitter to the exponential backoff + max_attempt_number=2, # Try twice + ) + + # The method invocation above is equivalent to the longer form below: + + runnable_with_retries = RunnableRetry( + bound=runnable, + retry_exception_types=(ValueError,), + max_attempt_number=2, + wait_exponential_jitter=True + ) + + This logic can be used to retry any Runnable, including a chain of Runnables, + but in general it's best practice to keep the scope of the retry as small as + possible. For example, if you have a chain of Runnables, you should only retry + the Runnable that is likely to fail, not the entire chain. + + Example: + + .. 
code-block:: python + + from langchain_core.chat_models import ChatOpenAI + from langchain_core.prompts import PromptTemplate + + template = PromptTemplate.from_template("tell me a joke about {topic}.") + model = ChatOpenAI(temperature=0.5) + + # Good + chain = template | model.with_retry() + + # Bad + chain = template | model + retryable_chain = chain.with_retry() + """ + + retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,) + """The exception types to retry on. By default all exceptions are retried. + + In general you should only retry on exceptions that are likely to be + transient, such as network errors. + + Good exceptions to retry are all server errors (5xx) and selected client + errors (4xx) such as 429 Too Many Requests. + """ + + wait_exponential_jitter: bool = True + """Whether to add jitter to the exponential backoff.""" + + max_attempt_number: int = 3 + """The maximum number of attempts to retry the runnable.""" + + @property + def _kwargs_retrying(self) -> Dict[str, Any]: + kwargs: Dict[str, Any] = dict() + + if self.max_attempt_number: + kwargs["stop"] = stop_after_attempt(self.max_attempt_number) + + if self.wait_exponential_jitter: + kwargs["wait"] = wait_exponential_jitter() + + if self.retry_exception_types: + kwargs["retry"] = retry_if_exception_type(self.retry_exception_types) + + return kwargs + + def _sync_retrying(self, **kwargs: Any) -> Retrying: + return Retrying(**self._kwargs_retrying, **kwargs) + + def _async_retrying(self, **kwargs: Any) -> AsyncRetrying: + return AsyncRetrying(**self._kwargs_retrying, **kwargs) + + def _patch_config( + self, + config: RunnableConfig, + run_manager: "T", + retry_state: RetryCallState, + ) -> RunnableConfig: + attempt = retry_state.attempt_number + tag = "retry:attempt:{}".format(attempt) if attempt > 1 else None + return patch_config(config, callbacks=run_manager.get_child(tag)) + + def _patch_config_list( + self, + config: List[RunnableConfig], + run_manager: List["T"], + retry_state: RetryCallState, + ) -> List[RunnableConfig]: + return [ + self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager) + ] + + def _invoke( + self, + input: Input, + run_manager: "CallbackManagerForChainRun", + config: RunnableConfig, + **kwargs: Any, + ) -> Output: + for attempt in self._sync_retrying(reraise=True): + with attempt: + result = super().invoke( + input, + self._patch_config(config, run_manager, attempt.retry_state), + **kwargs, + ) + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + return self._call_with_config(self._invoke, input, config, **kwargs) + + async def _ainvoke( + self, + input: Input, + run_manager: "AsyncCallbackManagerForChainRun", + config: RunnableConfig, + **kwargs: Any, + ) -> Output: + async for attempt in self._async_retrying(reraise=True): + with attempt: + result = await super().ainvoke( + input, + self._patch_config(config, run_manager, attempt.retry_state), + **kwargs, + ) + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + return await self._acall_with_config(self._ainvoke, input, config, **kwargs) + + def _batch( + self, + inputs: List[Input], + run_manager: 
List["CallbackManagerForChainRun"], + config: List[RunnableConfig], + **kwargs: Any, + ) -> List[Union[Output, Exception]]: + results_map: Dict[int, Output] = {} + + def pending(iterable: List[U]) -> List[U]: + return [item for idx, item in enumerate(iterable) if idx not in results_map] + + try: + for attempt in self._sync_retrying(): + with attempt: + # Get the results of the inputs that have not succeeded yet. + result = super().batch( + pending(inputs), + self._patch_config_list( + pending(config), pending(run_manager), attempt.retry_state + ), + return_exceptions=True, + **kwargs, + ) + # Register the results of the inputs that have succeeded. + first_exception = None + for i, r in enumerate(result): + if isinstance(r, Exception): + if not first_exception: + first_exception = r + continue + results_map[i] = r + # If any exception occurred, raise it, to retry the failed ones + if first_exception: + raise first_exception + if ( + attempt.retry_state.outcome + and not attempt.retry_state.outcome.failed + ): + attempt.retry_state.set_result(result) + except RetryError as e: + try: + result + except UnboundLocalError: + result = cast(List[Output], [e] * len(inputs)) + + outputs: List[Union[Output, Exception]] = [] + for idx, _ in enumerate(inputs): + if idx in results_map: + outputs.append(results_map[idx]) + else: + outputs.append(result.pop(0)) + return outputs + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> List[Output]: + return self._batch_with_config( + self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs + ) + + async def _abatch( + self, + inputs: List[Input], + run_manager: List["AsyncCallbackManagerForChainRun"], + config: List[RunnableConfig], + **kwargs: Any, + ) -> List[Union[Output, Exception]]: + results_map: Dict[int, Output] = {} + + def pending(iterable: List[U]) -> List[U]: + return [item for idx, item in enumerate(iterable) if idx not in results_map] + + try: + async for attempt in self._async_retrying(): + with attempt: + # Get the results of the inputs that have not succeeded yet. + result = await super().abatch( + pending(inputs), + self._patch_config_list( + pending(config), pending(run_manager), attempt.retry_state + ), + return_exceptions=True, + **kwargs, + ) + # Register the results of the inputs that have succeeded. 
+ first_exception = None + for i, r in enumerate(result): + if isinstance(r, Exception): + if not first_exception: + first_exception = r + continue + results_map[i] = r + # If any exception occurred, raise it, to retry the failed ones + if first_exception: + raise first_exception + if ( + attempt.retry_state.outcome + and not attempt.retry_state.outcome.failed + ): + attempt.retry_state.set_result(result) + except RetryError as e: + try: + result + except UnboundLocalError: + result = cast(List[Output], [e] * len(inputs)) + + outputs: List[Union[Output, Exception]] = [] + for idx, _ in enumerate(inputs): + if idx in results_map: + outputs.append(results_map[idx]) + else: + outputs.append(result.pop(0)) + return outputs + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> List[Output]: + return await self._abatch_with_config( + self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs + ) + + # stream() and transform() are not retried because retrying a stream + # is not very intuitive. diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py new file mode 100644 index 00000000000..0413d8a9110 --- /dev/null +++ b/libs/core/langchain_core/runnables/router.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +from typing import ( + Any, + AsyncIterator, + Callable, + Iterator, + List, + Mapping, + Optional, + Union, + cast, +) + +from typing_extensions import TypedDict + +from langchain_core.runnables.base import ( + Input, + Output, + Runnable, + RunnableSerializable, + coerce_to_runnable, +) +from langchain_core.runnables.config import ( + RunnableConfig, + get_config_list, + get_executor_for_config, +) +from langchain_core.runnables.utils import ( + ConfigurableFieldSpec, + gather_with_concurrency, + get_unique_config_specs, +) + + +class RouterInput(TypedDict): + """A Router input. + + Attributes: + key: The key to route on. + input: The input to pass to the selected runnable. + """ + + key: str + input: Any + + +class RouterRunnable(RunnableSerializable[RouterInput, Output]): + """ + A runnable that routes to a set of runnables based on Input['key']. + Returns the output of the selected runnable. 
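+
+    Example (an illustrative sketch; the runnables shown are hypothetical):
+
+    .. code-block:: python
+
+        from langchain_core.runnables.router import RouterRunnable
+
+        router = RouterRunnable(
+            runnables={"double": lambda x: x * 2, "square": lambda x: x * x}
+        )
+        router.invoke({"key": "square", "input": 3})  # -> 9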
+ """ + + runnables: Mapping[str, Runnable[Any, Output]] + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return get_unique_config_specs( + spec for step in self.runnables.values() for spec in step.config_specs + ) + + def __init__( + self, + runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]], + ) -> None: + super().__init__( + runnables={key: coerce_to_runnable(r) for key, r in runnables.items()} + ) + + class Config: + arbitrary_types_allowed = True + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + def invoke( + self, input: RouterInput, config: Optional[RunnableConfig] = None + ) -> Output: + key = input["key"] + actual_input = input["input"] + if key not in self.runnables: + raise ValueError(f"No runnable associated with key '{key}'") + + runnable = self.runnables[key] + return runnable.invoke(actual_input, config) + + async def ainvoke( + self, + input: RouterInput, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: + key = input["key"] + actual_input = input["input"] + if key not in self.runnables: + raise ValueError(f"No runnable associated with key '{key}'") + + runnable = self.runnables[key] + return await runnable.ainvoke(actual_input, config) + + def batch( + self, + inputs: List[RouterInput], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + if not inputs: + return [] + + keys = [input["key"] for input in inputs] + actual_inputs = [input["input"] for input in inputs] + if any(key not in self.runnables for key in keys): + raise ValueError("One or more keys do not have a corresponding runnable") + + def invoke( + runnable: Runnable, input: Input, config: RunnableConfig + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return runnable.invoke(input, config, **kwargs) + except Exception as e: + return e + else: + return runnable.invoke(input, config, **kwargs) + + runnables = [self.runnables[key] for key in keys] + configs = get_config_list(config, len(inputs)) + with get_executor_for_config(configs[0]) as executor: + return cast( + List[Output], + list(executor.map(invoke, runnables, actual_inputs, configs)), + ) + + async def abatch( + self, + inputs: List[RouterInput], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + if not inputs: + return [] + + keys = [input["key"] for input in inputs] + actual_inputs = [input["input"] for input in inputs] + if any(key not in self.runnables for key in keys): + raise ValueError("One or more keys do not have a corresponding runnable") + + async def ainvoke( + runnable: Runnable, input: Input, config: RunnableConfig + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return await runnable.ainvoke(input, config, **kwargs) + except Exception as e: + return e + else: + return await runnable.ainvoke(input, config, **kwargs) + + runnables = [self.runnables[key] for key in keys] + configs = get_config_list(config, len(inputs)) + return await gather_with_concurrency( + configs[0].get("max_concurrency"), + *( + ainvoke(runnable, input, config) + for runnable, input, config in zip(runnables, actual_inputs, configs) + ), + ) + + def stream( + self, + 
input: RouterInput, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + key = input["key"] + actual_input = input["input"] + if key not in self.runnables: + raise ValueError(f"No runnable associated with key '{key}'") + + runnable = self.runnables[key] + yield from runnable.stream(actual_input, config) + + async def astream( + self, + input: RouterInput, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + key = input["key"] + actual_input = input["input"] + if key not in self.runnables: + raise ValueError(f"No runnable associated with key '{key}'") + + runnable = self.runnables[key] + async for output in runnable.astream(actual_input, config): + yield output diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py new file mode 100644 index 00000000000..aafd9d59458 --- /dev/null +++ b/libs/core/langchain_core/runnables/utils.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +import ast +import asyncio +import inspect +import textwrap +from inspect import signature +from itertools import groupby +from typing import ( + Any, + AsyncIterable, + Callable, + Coroutine, + Dict, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Protocol, + Sequence, + Set, + TypeVar, + Union, +) + +Input = TypeVar("Input", contravariant=True) +# Output type should implement __concat__, as eg str, list, dict do +Output = TypeVar("Output", covariant=True) + + +async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any: + """Run a coroutine with a semaphore. + Args: + semaphore: The semaphore to use. + coro: The coroutine to run. + + Returns: + The result of the coroutine. + """ + async with semaphore: + return await coro + + +async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list: + """Gather coroutines with a limit on the number of concurrent coroutines.""" + if n is None: + return await asyncio.gather(*coros) + + semaphore = asyncio.Semaphore(n) + + return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros)) + + +def accepts_run_manager(callable: Callable[..., Any]) -> bool: + """Check if a callable accepts a run_manager argument.""" + try: + return signature(callable).parameters.get("run_manager") is not None + except ValueError: + return False + + +def accepts_config(callable: Callable[..., Any]) -> bool: + """Check if a callable accepts a config argument.""" + try: + return signature(callable).parameters.get("config") is not None + except ValueError: + return False + + +class IsLocalDict(ast.NodeVisitor): + """Check if a name is a local dict.""" + + def __init__(self, name: str, keys: Set[str]) -> None: + self.name = name + self.keys = keys + + def visit_Subscript(self, node: ast.Subscript) -> Any: + if ( + isinstance(node.ctx, ast.Load) + and isinstance(node.value, ast.Name) + and node.value.id == self.name + and isinstance(node.slice, ast.Constant) + and isinstance(node.slice.value, str) + ): + # we've found a subscript access on the name we're looking for + self.keys.add(node.slice.value) + + def visit_Call(self, node: ast.Call) -> Any: + if ( + isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == self.name + and node.func.attr == "get" + and len(node.args) in (1, 2) + and isinstance(node.args[0], ast.Constant) + and isinstance(node.args[0].value, str) + ): + # we've found a .get() call on the name we're looking for + 
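+            # record the key passed to .get() so it is included in the set of
+            # dict keys known to be accessed on this name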
self.keys.add(node.args[0].value) + + +class IsFunctionArgDict(ast.NodeVisitor): + """Check if the first argument of a function is a dict.""" + + def __init__(self) -> None: + self.keys: Set[str] = set() + + def visit_Lambda(self, node: ast.Lambda) -> Any: + input_arg_name = node.args.args[0].arg + IsLocalDict(input_arg_name, self.keys).visit(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: + input_arg_name = node.args.args[0].arg + IsLocalDict(input_arg_name, self.keys).visit(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: + input_arg_name = node.args.args[0].arg + IsLocalDict(input_arg_name, self.keys).visit(node) + + +class GetLambdaSource(ast.NodeVisitor): + """Get the source code of a lambda function.""" + + def __init__(self) -> None: + """Initialize the visitor.""" + self.source: Optional[str] = None + self.count = 0 + + def visit_Lambda(self, node: ast.Lambda) -> Any: + """Visit a lambda function.""" + self.count += 1 + if hasattr(ast, "unparse"): + self.source = ast.unparse(node) + + +def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]: + """Get the keys of the first argument of a function if it is a dict.""" + try: + code = inspect.getsource(func) + tree = ast.parse(textwrap.dedent(code)) + visitor = IsFunctionArgDict() + visitor.visit(tree) + return list(visitor.keys) if visitor.keys else None + except (SyntaxError, TypeError, OSError): + return None + + +def get_lambda_source(func: Callable) -> Optional[str]: + """Get the source code of a lambda function. + + Args: + func: a callable that can be a lambda function + + Returns: + str: the source code of the lambda function + """ + try: + code = inspect.getsource(func) + tree = ast.parse(textwrap.dedent(code)) + visitor = GetLambdaSource() + visitor.visit(tree) + return visitor.source if visitor.count == 1 else None + except (SyntaxError, TypeError, OSError): + return None + + +def indent_lines_after_first(text: str, prefix: str) -> str: + """Indent all lines of text after the first line. + + Args: + text: The text to indent + prefix: Used to determine the number of spaces to indent + + Returns: + str: The indented text + """ + n_spaces = len(prefix) + spaces = " " * n_spaces + lines = text.splitlines() + return "\n".join([lines[0]] + [spaces + line for line in lines[1:]]) + + +class AddableDict(Dict[str, Any]): + """ + Dictionary that can be added to another dictionary. + """ + + def __add__(self, other: AddableDict) -> AddableDict: + chunk = AddableDict(self) + for key in other: + if key not in chunk or chunk[key] is None: + chunk[key] = other[key] + elif other[key] is not None: + try: + added = chunk[key] + other[key] + except TypeError: + added = other[key] + chunk[key] = added + return chunk + + def __radd__(self, other: AddableDict) -> AddableDict: + chunk = AddableDict(other) + for key in self: + if key not in chunk or chunk[key] is None: + chunk[key] = self[key] + elif self[key] is not None: + try: + added = chunk[key] + self[key] + except TypeError: + added = self[key] + chunk[key] = added + return chunk + + +_T_co = TypeVar("_T_co", covariant=True) +_T_contra = TypeVar("_T_contra", contravariant=True) + + +class SupportsAdd(Protocol[_T_contra, _T_co]): + """Protocol for objects that support addition.""" + + def __add__(self, __x: _T_contra) -> _T_co: + ... 
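+
+
+# Illustrative sketch of how the addable helpers combine streamed chunks
+# (AddableDict is defined above, add() just below; the values are hypothetical):
+#
+#   add([AddableDict({"answer": "foo"}), AddableDict({"answer": "bar", "done": True})])
+#   # -> {"answer": "foobar", "done": True}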
+ + +Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any]) + + +def add(addables: Iterable[Addable]) -> Optional[Addable]: + """Add a sequence of addable objects together.""" + final = None + for chunk in addables: + if final is None: + final = chunk + else: + final = final + chunk + return final + + +async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]: + """Asynchronously add a sequence of addable objects together.""" + final = None + async for chunk in addables: + if final is None: + final = chunk + else: + final = final + chunk + return final + + +class ConfigurableField(NamedTuple): + """A field that can be configured by the user.""" + + id: str + + name: Optional[str] = None + description: Optional[str] = None + annotation: Optional[Any] = None + + def __hash__(self) -> int: + return hash((self.id, self.annotation)) + + +class ConfigurableFieldSingleOption(NamedTuple): + """A field that can be configured by the user with a default value.""" + + id: str + options: Mapping[str, Any] + default: str + + name: Optional[str] = None + description: Optional[str] = None + + def __hash__(self) -> int: + return hash((self.id, tuple(self.options.keys()), self.default)) + + +class ConfigurableFieldMultiOption(NamedTuple): + """A field that can be configured by the user with multiple default values.""" + + id: str + options: Mapping[str, Any] + default: Sequence[str] + + name: Optional[str] = None + description: Optional[str] = None + + def __hash__(self) -> int: + return hash((self.id, tuple(self.options.keys()), tuple(self.default))) + + +AnyConfigurableField = Union[ + ConfigurableField, ConfigurableFieldSingleOption, ConfigurableFieldMultiOption +] + + +class ConfigurableFieldSpec(NamedTuple): + """A field that can be configured by the user. 
It is a specification of a field.""" + + id: str + name: Optional[str] + description: Optional[str] + + default: Any + annotation: Any + + +def get_unique_config_specs( + specs: Iterable[ConfigurableFieldSpec], +) -> List[ConfigurableFieldSpec]: + """Get the unique config specs from a sequence of config specs.""" + grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id) + unique: List[ConfigurableFieldSpec] = [] + for id, dupes in grouped: + first = next(dupes) + others = list(dupes) + if len(others) == 0: + unique.append(first) + elif all(o == first for o in others): + unique.append(first) + else: + raise ValueError( + "RunnableSequence contains conflicting config specs" + f"for {id}: {[first] + others}" + ) + return unique diff --git a/libs/core/langchain_core/schema/__init__.py b/libs/core/langchain_core/schema/__init__.py new file mode 100644 index 00000000000..7e1742cc327 --- /dev/null +++ b/libs/core/langchain_core/schema/__init__.py @@ -0,0 +1,78 @@ +"""**Schemas** are the LangChain Base Classes and Interfaces.""" +from langchain_core.schema.agent import AgentAction, AgentFinish +from langchain_core.schema.cache import BaseCache +from langchain_core.schema.chat_history import BaseChatMessageHistory +from langchain_core.schema.document import BaseDocumentTransformer, Document +from langchain_core.schema.exceptions import LangChainException +from langchain_core.schema.memory import BaseMemory +from langchain_core.schema.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, + _message_from_dict, + _message_to_dict, + get_buffer_string, + messages_from_dict, + messages_to_dict, +) +from langchain_core.schema.output import ( + ChatGeneration, + ChatResult, + Generation, + LLMResult, + RunInfo, +) +from langchain_core.schema.output_parser import ( + BaseLLMOutputParser, + BaseOutputParser, + OutputParserException, + StrOutputParser, +) +from langchain_core.schema.prompt import PromptValue +from langchain_core.schema.prompt_template import BasePromptTemplate, format_document +from langchain_core.schema.retriever import BaseRetriever +from langchain_core.schema.storage import BaseStore + +RUN_KEY = "__run" +Memory = BaseMemory + +__all__ = [ + "BaseCache", + "BaseMemory", + "BaseStore", + "AgentFinish", + "AgentAction", + "Document", + "BaseChatMessageHistory", + "BaseDocumentTransformer", + "BaseMessage", + "ChatMessage", + "FunctionMessage", + "HumanMessage", + "AIMessage", + "SystemMessage", + "messages_from_dict", + "messages_to_dict", + "_message_to_dict", + "_message_from_dict", + "get_buffer_string", + "RunInfo", + "LLMResult", + "ChatResult", + "ChatGeneration", + "Generation", + "PromptValue", + "LangChainException", + "BaseRetriever", + "RUN_KEY", + "Memory", + "OutputParserException", + "StrOutputParser", + "BaseOutputParser", + "BaseLLMOutputParser", + "BasePromptTemplate", + "format_document", +] diff --git a/libs/core/langchain_core/schema/agent.py b/libs/core/langchain_core/schema/agent.py new file mode 100644 index 00000000000..94d9d60dc18 --- /dev/null +++ b/libs/core/langchain_core/schema/agent.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import Any, Literal, Sequence, Union + +from langchain_core.load.serializable import Serializable +from langchain_core.schema.messages import BaseMessage + + +class AgentAction(Serializable): + """A full description of an action for an ActionAgent to execute.""" + + tool: str + """The name of the Tool to execute.""" + tool_input: Union[str, 
dict] + """The input to pass in to the Tool.""" + log: str + """Additional information to log about the action. + This log can be used in a few ways. First, it can be used to audit + what exactly the LLM predicted to lead to this (tool, tool_input). + Second, it can be used in future iterations to show the LLMs prior + thoughts. This is useful when (tool, tool_input) does not contain + full information about the LLM prediction (for example, any `thought` + before the tool/tool_input).""" + type: Literal["AgentAction"] = "AgentAction" + + def __init__( + self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any + ): + """Override init to support instantiation by position for backward compat.""" + super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs) + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether or not the class is serializable.""" + return True + + +class AgentActionMessageLog(AgentAction): + message_log: Sequence[BaseMessage] + """Similar to log, this can be used to pass along extra + information about what exact messages were predicted by the LLM + before parsing out the (tool, tool_input). This is again useful + if (tool, tool_input) cannot be used to fully recreate the LLM + prediction, and you need that LLM prediction (for future agent iteration). + Compared to `log`, this is useful when the underlying LLM is a + ChatModel (and therefore returns messages rather than a string).""" + # Ignoring type because we're overriding the type from AgentAction. + # And this is the correct thing to do in this case. + # The type literal is used for serialization purposes. + type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore + + +class AgentFinish(Serializable): + """The final return value of an ActionAgent.""" + + return_values: dict + """Dictionary of return values.""" + log: str + """Additional information to log about the return value. + This is used to pass along the full LLM prediction, not just the parsed out + return value. For example, if the full LLM prediction was + `Final Answer: 2` you may want to just return `2` as a return value, but pass + along the full string as a `log` (for debugging or observability purposes). 
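+
+    For example (illustrative values):
+
+    .. code-block:: python
+
+        AgentFinish(
+            return_values={"output": "2"},
+            log="Final Answer: 2",
+        )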
+ """ + type: Literal["AgentFinish"] = "AgentFinish" + + def __init__(self, return_values: dict, log: str, **kwargs: Any): + """Override init to support instantiation by position for backward compat.""" + super().__init__(return_values=return_values, log=log, **kwargs) + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether or not the class is serializable.""" + return True diff --git a/libs/core/langchain_core/schema/cache.py b/libs/core/langchain_core/schema/cache.py new file mode 100644 index 00000000000..fe132c5728d --- /dev/null +++ b/libs/core/langchain_core/schema/cache.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Optional, Sequence + +from langchain_core.schema.output import Generation + +RETURN_VAL_TYPE = Sequence[Generation] + + +class BaseCache(ABC): + """Base interface for cache.""" + + @abstractmethod + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + + @abstractmethod + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + + @abstractmethod + def clear(self, **kwargs: Any) -> None: + """Clear cache that can take additional keyword arguments.""" diff --git a/libs/core/langchain_core/schema/chat.py b/libs/core/langchain_core/schema/chat.py new file mode 100644 index 00000000000..83c0789f1fe --- /dev/null +++ b/libs/core/langchain_core/schema/chat.py @@ -0,0 +1,13 @@ +from typing import Sequence, TypedDict + +from langchain_core.schema import BaseMessage + + +class ChatSession(TypedDict, total=False): + """Chat Session represents a single + conversation, channel, or other group of messages.""" + + messages: Sequence[BaseMessage] + """The LangChain chat messages loaded from the source.""" + functions: Sequence[dict] + """The function calling specs for the messages.""" diff --git a/libs/core/langchain_core/schema/chat_history.py b/libs/core/langchain_core/schema/chat_history.py new file mode 100644 index 00000000000..d3e74e68824 --- /dev/null +++ b/libs/core/langchain_core/schema/chat_history.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List + +from langchain_core.schema.messages import AIMessage, BaseMessage, HumanMessage + + +class BaseChatMessageHistory(ABC): + """Abstract base class for storing chat message history. + + See `ChatMessageHistory` for default implementation. + + Example: + .. code-block:: python + + class FileChatMessageHistory(BaseChatMessageHistory): + storage_path: str + session_id: str + + @property + def messages(self): + with open(os.path.join(storage_path, session_id), 'r:utf-8') as f: + messages = json.loads(f.read()) + return messages_from_dict(messages) + + def add_message(self, message: BaseMessage) -> None: + messages = self.messages.append(_message_to_dict(message)) + with open(os.path.join(storage_path, session_id), 'w') as f: + json.dump(f, messages) + + def clear(self): + with open(os.path.join(storage_path, session_id), 'w') as f: + f.write("[]") + """ + + messages: List[BaseMessage] + """A list of Messages stored in-memory.""" + + def add_user_message(self, message: str) -> None: + """Convenience method for adding a human message string to the store. + + Args: + message: The string contents of a human message. 
+ """ + self.add_message(HumanMessage(content=message)) + + def add_ai_message(self, message: str) -> None: + """Convenience method for adding an AI message string to the store. + + Args: + message: The string contents of an AI message. + """ + self.add_message(AIMessage(content=message)) + + @abstractmethod + def add_message(self, message: BaseMessage) -> None: + """Add a Message object to the store. + + Args: + message: A BaseMessage object to store. + """ + raise NotImplementedError() + + @abstractmethod + def clear(self) -> None: + """Remove all messages from the store""" diff --git a/libs/core/langchain_core/schema/document.py b/libs/core/langchain_core/schema/document.py new file mode 100644 index 00000000000..448e31532d1 --- /dev/null +++ b/libs/core/langchain_core/schema/document.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Literal, Sequence + +from langchain_core.load.serializable import Serializable +from langchain_core.pydantic_v1 import Field + + +class Document(Serializable): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + """String text.""" + metadata: dict = Field(default_factory=dict) + """Arbitrary metadata about the page content (e.g., source, relationships to other + documents, etc.). + """ + type: Literal["Document"] = "Document" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + + +class BaseDocumentTransformer(ABC): + """Abstract base class for document transformation systems. + + A document transformation system takes a sequence of Documents and returns a + sequence of transformed Documents. + + Example: + .. code-block:: python + + class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): + embeddings: Embeddings + similarity_fn: Callable = cosine_similarity + similarity_threshold: float = 0.95 + + class Config: + arbitrary_types_allowed = True + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + stateful_documents = get_stateful_documents(documents) + embedded_documents = _get_embeddings_from_stateful_docs( + self.embeddings, stateful_documents + ) + included_idxs = _filter_similar_embeddings( + embedded_documents, self.similarity_fn, self.similarity_threshold + ) + return [stateful_documents[i] for i in sorted(included_idxs)] + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + raise NotImplementedError + + """ # noqa: E501 + + @abstractmethod + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Asynchronously transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. 
+ """ + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.transform_documents, **kwargs), documents + ) diff --git a/libs/core/langchain_core/schema/embeddings.py b/libs/core/langchain_core/schema/embeddings.py new file mode 100644 index 00000000000..c08a279750b --- /dev/null +++ b/libs/core/langchain_core/schema/embeddings.py @@ -0,0 +1,27 @@ +import asyncio +from abc import ABC, abstractmethod +from typing import List + + +class Embeddings(ABC): + """Interface for embedding models.""" + + @abstractmethod + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed search docs.""" + + @abstractmethod + def embed_query(self, text: str) -> List[float]: + """Embed query text.""" + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + """Asynchronous Embed search docs.""" + return await asyncio.get_running_loop().run_in_executor( + None, self.embed_documents, texts + ) + + async def aembed_query(self, text: str) -> List[float]: + """Asynchronous Embed query text.""" + return await asyncio.get_running_loop().run_in_executor( + None, self.embed_query, text + ) diff --git a/libs/core/langchain_core/schema/exceptions.py b/libs/core/langchain_core/schema/exceptions.py new file mode 100644 index 00000000000..27ed0d07dc1 --- /dev/null +++ b/libs/core/langchain_core/schema/exceptions.py @@ -0,0 +1,2 @@ +class LangChainException(Exception): + """General LangChain exception.""" diff --git a/libs/core/langchain_core/schema/language_model.py b/libs/core/langchain_core/schema/language_model.py new file mode 100644 index 00000000000..df22e8b3276 --- /dev/null +++ b/libs/core/langchain_core/schema/language_model.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + List, + Optional, + Sequence, + Set, + TypeVar, + Union, +) + +from typing_extensions import TypeAlias + +from langchain_core.runnables import RunnableSerializable +from langchain_core.schema.messages import AnyMessage, BaseMessage, get_buffer_string +from langchain_core.schema.output import LLMResult +from langchain_core.schema.prompt import PromptValue +from langchain_core.utils import get_pydantic_field_names + +if TYPE_CHECKING: + from langchain_core.callbacks.manager import Callbacks + + +@lru_cache(maxsize=None) # Cache the tokenizer +def get_tokenizer() -> Any: + try: + from transformers import GPT2TokenizerFast + except ImportError: + raise ImportError( + "Could not import transformers python package. " + "This is needed in order to calculate get_token_ids. " + "Please install it with `pip install transformers`." + ) + # create a GPT-2 tokenizer instance + return GPT2TokenizerFast.from_pretrained("gpt2") + + +def _get_token_ids_default_method(text: str) -> List[int]: + """Encode the text into token IDs.""" + # get the cached tokenizer + tokenizer = get_tokenizer() + + # tokenize the text using the GPT-2 tokenizer + return tokenizer.encode(text) + + +LanguageModelInput = Union[PromptValue, str, List[BaseMessage]] +LanguageModelOutput = TypeVar("LanguageModelOutput") + + +class BaseLanguageModel( + RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC +): + """Abstract base class for interfacing with language models. + + All language model wrappers inherit from BaseLanguageModel. + + Exposes three main methods: + - generate_prompt: generate language model outputs for a sequence of prompt + values. 
A prompt value is a model input that can be converted to any language + model input format (string or messages). + - predict: pass in a single string to a language model and return a string + prediction. + - predict_messages: pass in a sequence of BaseMessages (corresponding to a single + model call) to a language model and return a BaseMessage prediction. + + Each of these has an equivalent asynchronous method. + """ + + @property + def InputType(self) -> TypeAlias: + """Get the input type for this runnable.""" + from langchain_core.prompts.base import StringPromptValue + from langchain_core.prompts.chat import ChatPromptValueConcrete + + # This is a version of LanguageModelInput which replaces the abstract + # base class BaseMessage with a union of its subclasses, which makes + # for a much better schema. + return Union[ + str, + Union[StringPromptValue, ChatPromptValueConcrete], + List[AnyMessage], + ] + + @abstractmethod + def generate_prompt( + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, + ) -> LLMResult: + """Pass a sequence of prompts to the model and return model generations. + + This method should make use of batched calls for models that expose a batched + API. + + Use this method when you want to: + 1. take advantage of batched calls, + 2. need more output from the model than just the top generated value, + 3. are building chains that are agnostic to the underlying language model + type (e.g., pure text completion models vs chat models). + + Args: + prompts: List of PromptValues. A PromptValue is an object that can be + converted to match the format of any language model (string for pure + text generation models and BaseMessages for chat models). + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + callbacks: Callbacks to pass through. Used for executing additional + functionality, such as logging or streaming, throughout generation. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + An LLMResult, which contains a list of candidate Generations for each input + prompt and additional model provider-specific output. + """ + + @abstractmethod + async def agenerate_prompt( + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, + ) -> LLMResult: + """Asynchronously pass a sequence of prompts and return model generations. + + This method should make use of batched calls for models that expose a batched + API. + + Use this method when you want to: + 1. take advantage of batched calls, + 2. need more output from the model than just the top generated value, + 3. are building chains that are agnostic to the underlying language model + type (e.g., pure text completion models vs chat models). + + Args: + prompts: List of PromptValues. A PromptValue is an object that can be + converted to match the format of any language model (string for pure + text generation models and BaseMessages for chat models). + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + callbacks: Callbacks to pass through. Used for executing additional + functionality, such as logging or streaming, throughout generation. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. 
+ + Returns: + An LLMResult, which contains a list of candidate Generations for each input + prompt and additional model provider-specific output. + """ + + @abstractmethod + def predict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: + """Pass a single string input to the model and return a string prediction. + + Use this method when passing in raw text. If you want to pass in specific + types of chat messages, use predict_messages. + + Args: + text: String input to pass to the model. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + Top model prediction as a string. + """ + + @abstractmethod + def predict_messages( + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> BaseMessage: + """Pass a message sequence to the model and return a message prediction. + + Use this method when passing in chat messages. If you want to pass in raw text, + use predict. + + Args: + messages: A sequence of chat messages corresponding to a single model input. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + Top model prediction as a message. + """ + + @abstractmethod + async def apredict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: + """Asynchronously pass a string to the model and return a string prediction. + + Use this method when calling pure text generation models and only the top + candidate generation is needed. + + Args: + text: String input to pass to the model. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + Top model prediction as a string. + """ + + @abstractmethod + async def apredict_messages( + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> BaseMessage: + """Asynchronously pass messages to the model and return a message prediction. + + Use this method when calling chat models and only the top + candidate generation is needed. + + Args: + messages: A sequence of chat messages corresponding to a single model input. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + Top model prediction as a message. + """ + + def get_token_ids(self, text: str) -> List[int]: + """Return the ordered ids of the tokens in a text. + + Args: + text: The string input to tokenize. + + Returns: + A list of ids corresponding to the tokens in the text, in order they occur + in the text. + """ + return _get_token_ids_default_method(text) + + def get_num_tokens(self, text: str) -> int: + """Get the number of tokens present in the text. + + Useful for checking if an input will fit in a model's context window. + + Args: + text: The string input to tokenize. + + Returns: + The integer number of tokens in the text. 
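+
+        Example (illustrative; the count depends on the tokenizer in use):
+
+        .. code-block:: python
+
+            if model.get_num_tokens(prompt_text) > context_window:
+                ...  # e.g. truncate or summarize the prompt first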
+ """ + return len(self.get_token_ids(text)) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Get the number of tokens in the messages. + + Useful for checking if an input will fit in a model's context window. + + Args: + messages: The message inputs to tokenize. + + Returns: + The sum of the number of tokens across the messages. + """ + return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages]) + + @classmethod + def _all_required_field_names(cls) -> Set: + """DEPRECATED: Kept for backwards compatibility. + + Use get_pydantic_field_names. + """ + return get_pydantic_field_names(cls) diff --git a/libs/core/langchain_core/schema/memory.py b/libs/core/langchain_core/schema/memory.py new file mode 100644 index 00000000000..0b362661cfd --- /dev/null +++ b/libs/core/langchain_core/schema/memory.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from langchain_core.load.serializable import Serializable + + +class BaseMemory(Serializable, ABC): + """Abstract base class for memory in Chains. + + Memory refers to state in Chains. Memory can be used to store information about + past executions of a Chain and inject that information into the inputs of + future executions of the Chain. For example, for conversational Chains Memory + can be used to store conversations and automatically add them to future model + prompts so that the model has the necessary context to respond coherently to + the latest input. + + Example: + .. code-block:: python + + class SimpleMemory(BaseMemory): + memories: Dict[str, Any] = dict() + + @property + def memory_variables(self) -> List[str]: + return list(self.memories.keys()) + + def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: + return self.memories + + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + pass + + def clear(self) -> None: + pass + """ # noqa: E501 + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @property + @abstractmethod + def memory_variables(self) -> List[str]: + """The string keys this memory class will add to chain inputs.""" + + @abstractmethod + def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Return key-value pairs given the text input to the chain.""" + + @abstractmethod + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + """Save the context of this chain run to memory.""" + + @abstractmethod + def clear(self) -> None: + """Clear memory contents.""" diff --git a/libs/core/langchain_core/schema/messages.py b/libs/core/langchain_core/schema/messages.py new file mode 100644 index 00000000000..9f96ce68ad9 --- /dev/null +++ b/libs/core/langchain_core/schema/messages.py @@ -0,0 +1,415 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union + +from typing_extensions import Literal + +from langchain_core.load.serializable import Serializable +from langchain_core.pydantic_v1 import Extra, Field + +if TYPE_CHECKING: + from langchain_core.prompts.chat import ChatPromptTemplate + + +def get_buffer_string( + messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI" +) -> str: + """Convert sequence of Messages to strings and concatenate them into one string. + + Args: + messages: Messages to be converted to strings. 
+ human_prefix: The prefix to prepend to contents of HumanMessages. + ai_prefix: THe prefix to prepend to contents of AIMessages. + + Returns: + A single string concatenation of all input messages. + + Example: + .. code-block:: python + + from langchain_core.schema import AIMessage, HumanMessage + + messages = [ + HumanMessage(content="Hi, how are you?"), + AIMessage(content="Good, how are you?"), + ] + get_buffer_string(messages) + # -> "Human: Hi, how are you?\nAI: Good, how are you?" + """ + string_messages = [] + for m in messages: + if isinstance(m, HumanMessage): + role = human_prefix + elif isinstance(m, AIMessage): + role = ai_prefix + elif isinstance(m, SystemMessage): + role = "System" + elif isinstance(m, FunctionMessage): + role = "Function" + elif isinstance(m, ChatMessage): + role = m.role + else: + raise ValueError(f"Got unsupported message type: {m}") + message = f"{role}: {m.content}" + if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs: + message += f"{m.additional_kwargs['function_call']}" + string_messages.append(message) + + return "\n".join(string_messages) + + +class BaseMessage(Serializable): + """The base abstract Message class. + + Messages are the inputs and outputs of ChatModels. + """ + + content: Union[str, List[Union[str, Dict]]] + """The string contents of the message.""" + + additional_kwargs: dict = Field(default_factory=dict) + """Any additional information.""" + + type: str + + class Config: + extra = Extra.allow + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + + def __add__(self, other: Any) -> ChatPromptTemplate: + from langchain_core.prompts.chat import ChatPromptTemplate + + prompt = ChatPromptTemplate(messages=[self]) + return prompt + other + + +def merge_content( + first_content: Union[str, List[Union[str, Dict]]], + second_content: Union[str, List[Union[str, Dict]]], +) -> Union[str, List[Union[str, Dict]]]: + # If first chunk is a string + if isinstance(first_content, str): + # If the second chunk is also a string, then merge them naively + if isinstance(second_content, str): + return first_content + second_content + # If the second chunk is a list, add the first chunk to the start of the list + else: + return_list: List[Union[str, Dict]] = [first_content] + return return_list + second_content + # If both are lists, merge them naively + elif isinstance(second_content, List): + return first_content + second_content + # If the first content is a list, and the second content is a string + else: + # If the last element of the first content is a string + # Add the second content to the last element + if isinstance(first_content[-1], str): + return first_content[:-1] + [first_content[-1] + second_content] + else: + # Otherwise, add the second content as a new element of the list + return first_content + [second_content] + + +class BaseMessageChunk(BaseMessage): + """A Message chunk, which can be concatenated with other Message chunks.""" + + def _merge_kwargs_dict( + self, left: Dict[str, Any], right: Dict[str, Any] + ) -> Dict[str, Any]: + """Merge additional_kwargs from another BaseMessageChunk into this one.""" + merged = left.copy() + for k, v in right.items(): + if k not in merged: + merged[k] = v + elif type(merged[k]) != type(v): + raise ValueError( + f'additional_kwargs["{k}"] already exists in this message,' + " but with a different type." 
+ ) + elif isinstance(merged[k], str): + merged[k] += v + elif isinstance(merged[k], dict): + merged[k] = self._merge_kwargs_dict(merged[k], v) + else: + raise ValueError( + f"Additional kwargs key {k} already exists in this message." + ) + return merged + + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore + if isinstance(other, BaseMessageChunk): + # If both are (subclasses of) BaseMessageChunk, + # concat into a single BaseMessageChunk + + if isinstance(self, ChatMessageChunk): + return self.__class__( + role=self.role, + content=merge_content(self.content, other.content), + additional_kwargs=self._merge_kwargs_dict( + self.additional_kwargs, other.additional_kwargs + ), + ) + return self.__class__( + content=merge_content(self.content, other.content), + additional_kwargs=self._merge_kwargs_dict( + self.additional_kwargs, other.additional_kwargs + ), + ) + else: + raise TypeError( + 'unsupported operand type(s) for +: "' + f"{self.__class__.__name__}" + f'" and "{other.__class__.__name__}"' + ) + + +class HumanMessage(BaseMessage): + """A Message from a human.""" + + example: bool = False + """Whether this Message is being passed in to the model as part of an example + conversation. + """ + + type: Literal["human"] = "human" + + +HumanMessage.update_forward_refs() + + +class HumanMessageChunk(HumanMessage, BaseMessageChunk): + """A Human Message chunk.""" + + # Ignoring mypy re-assignment here since we're overriding the value + # to make sure that the chunk variant can be discriminated from the + # non-chunk variant. + type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] # noqa: E501 + + +class AIMessage(BaseMessage): + """A Message from an AI.""" + + example: bool = False + """Whether this Message is being passed in to the model as part of an example + conversation. + """ + + type: Literal["ai"] = "ai" + + +AIMessage.update_forward_refs() + + +class AIMessageChunk(AIMessage, BaseMessageChunk): + """A Message chunk from an AI.""" + + # Ignoring mypy re-assignment here since we're overriding the value + # to make sure that the chunk variant can be discriminated from the + # non-chunk variant. + type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501 + + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore + if isinstance(other, AIMessageChunk): + if self.example != other.example: + raise ValueError( + "Cannot concatenate AIMessageChunks with different example values." + ) + + return self.__class__( + example=self.example, + content=merge_content(self.content, other.content), + additional_kwargs=self._merge_kwargs_dict( + self.additional_kwargs, other.additional_kwargs + ), + ) + + return super().__add__(other) + + +class SystemMessage(BaseMessage): + """A Message for priming AI behavior, usually passed in as the first of a sequence + of input messages. + """ + + type: Literal["system"] = "system" + + +SystemMessage.update_forward_refs() + + +class SystemMessageChunk(SystemMessage, BaseMessageChunk): + """A System Message chunk.""" + + # Ignoring mypy re-assignment here since we're overriding the value + # to make sure that the chunk variant can be discriminated from the + # non-chunk variant. 
+ type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] # noqa: E501 + + +class FunctionMessage(BaseMessage): + """A Message for passing the result of executing a function back to a model.""" + + name: str + """The name of the function that was executed.""" + + type: Literal["function"] = "function" + + +FunctionMessage.update_forward_refs() + + +class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): + """A Function Message chunk.""" + + # Ignoring mypy re-assignment here since we're overriding the value + # to make sure that the chunk variant can be discriminated from the + # non-chunk variant. + type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment] + + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore + if isinstance(other, FunctionMessageChunk): + if self.name != other.name: + raise ValueError( + "Cannot concatenate FunctionMessageChunks with different names." + ) + + return self.__class__( + name=self.name, + content=merge_content(self.content, other.content), + additional_kwargs=self._merge_kwargs_dict( + self.additional_kwargs, other.additional_kwargs + ), + ) + + return super().__add__(other) + + +class ToolMessage(BaseMessage): + """A Message for passing the result of executing a tool back to a model.""" + + tool_call_id: str + """Tool call that this message is responding to.""" + + type: Literal["tool"] = "tool" + + +ToolMessage.update_forward_refs() + + +class ToolMessageChunk(ToolMessage, BaseMessageChunk): + """A Tool Message chunk.""" + + # Ignoring mypy re-assignment here since we're overriding the value + # to make sure that the chunk variant can be discriminated from the + # non-chunk variant. + type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment] + + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore + if isinstance(other, ToolMessageChunk): + if self.tool_call_id != other.tool_call_id: + raise ValueError( + "Cannot concatenate ToolMessageChunks with different names." + ) + + return self.__class__( + tool_call_id=self.tool_call_id, + content=merge_content(self.content, other.content), + additional_kwargs=self._merge_kwargs_dict( + self.additional_kwargs, other.additional_kwargs + ), + ) + + return super().__add__(other) + + +class ChatMessage(BaseMessage): + """A Message that can be assigned an arbitrary speaker (i.e. role).""" + + role: str + """The speaker / role of the Message.""" + + type: Literal["chat"] = "chat" + + +ChatMessage.update_forward_refs() + + +class ChatMessageChunk(ChatMessage, BaseMessageChunk): + """A Chat Message chunk.""" + + # Ignoring mypy re-assignment here since we're overriding the value + # to make sure that the chunk variant can be discriminated from the + # non-chunk variant. + type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore + + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore + if isinstance(other, ChatMessageChunk): + if self.role != other.role: + raise ValueError( + "Cannot concatenate ChatMessageChunks with different roles." 
+ ) + + return self.__class__( + role=self.role, + content=merge_content(self.content, other.content), + additional_kwargs=self._merge_kwargs_dict( + self.additional_kwargs, other.additional_kwargs + ), + ) + + return super().__add__(other) + + +AnyMessage = Union[ + AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage +] + + +def _message_to_dict(message: BaseMessage) -> dict: + return {"type": message.type, "data": message.dict()} + + +def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]: + """Convert a sequence of Messages to a list of dictionaries. + + Args: + messages: Sequence of messages (as BaseMessages) to convert. + + Returns: + List of messages as dicts. + """ + return [_message_to_dict(m) for m in messages] + + +def _message_from_dict(message: dict) -> BaseMessage: + _type = message["type"] + if _type == "human": + return HumanMessage(**message["data"]) + elif _type == "ai": + return AIMessage(**message["data"]) + elif _type == "system": + return SystemMessage(**message["data"]) + elif _type == "chat": + return ChatMessage(**message["data"]) + elif _type == "function": + return FunctionMessage(**message["data"]) + elif _type == "tool": + return ToolMessage(**message["data"]) + else: + raise ValueError(f"Got unexpected message type: {_type}") + + +def messages_from_dict(messages: List[dict]) -> List[BaseMessage]: + """Convert a sequence of messages from dicts to Message objects. + + Args: + messages: Sequence of messages (as dicts) to convert. + + Returns: + List of messages (BaseMessages). + """ + return [_message_from_dict(m) for m in messages] diff --git a/libs/core/langchain_core/schema/output.py b/libs/core/langchain_core/schema/output.py new file mode 100644 index 00000000000..a4ca64beb4d --- /dev/null +++ b/libs/core/langchain_core/schema/output.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, List, Literal, Optional +from uuid import UUID + +from langchain_core.load.serializable import Serializable +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema.messages import BaseMessage, BaseMessageChunk + + +class Generation(Serializable): + """A single text generation output.""" + + text: str + """Generated text output.""" + + generation_info: Optional[Dict[str, Any]] = None + """Raw response from the provider. May include things like the + reason for finishing or token log probabilities. 
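+
+    For example, a provider might populate it with something like
+    (illustrative only):
+
+    .. code-block:: python
+
+        Generation(
+            text="Paris",
+            generation_info={"finish_reason": "stop", "logprobs": None},
+        )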
+ """ + type: Literal["Generation"] = "Generation" + """Type is used exclusively for serialization purposes.""" + # TODO: add log probs as separate attribute + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + + +class GenerationChunk(Generation): + """A Generation chunk, which can be concatenated with other Generation chunks.""" + + def __add__(self, other: GenerationChunk) -> GenerationChunk: + if isinstance(other, GenerationChunk): + generation_info = ( + {**(self.generation_info or {}), **(other.generation_info or {})} + if self.generation_info is not None or other.generation_info is not None + else None + ) + return GenerationChunk( + text=self.text + other.text, + generation_info=generation_info, + ) + else: + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) + + +class ChatGeneration(Generation): + """A single chat generation output.""" + + text: str = "" + """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message.""" + message: BaseMessage + """The message output by the chat model.""" + # Override type to be ChatGeneration, ignore mypy error as this is intentional + type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment] + """Type is used exclusively for serialization purposes.""" + + @root_validator + def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Set the text attribute to be the contents of the message.""" + try: + values["text"] = values["message"].content + except (KeyError, AttributeError) as e: + raise ValueError("Error while initializing ChatGeneration") from e + return values + + +class ChatGenerationChunk(ChatGeneration): + """A ChatGeneration chunk, which can be concatenated with other + ChatGeneration chunks. + + Attributes: + message: The message chunk output by the chat model. + """ + + message: BaseMessageChunk + # Override type to be ChatGeneration, ignore mypy error as this is intentional + type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501 + """Type is used exclusively for serialization purposes.""" + + def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk: + if isinstance(other, ChatGenerationChunk): + generation_info = ( + {**(self.generation_info or {}), **(other.generation_info or {})} + if self.generation_info is not None or other.generation_info is not None + else None + ) + return ChatGenerationChunk( + message=self.message + other.message, + generation_info=generation_info, + ) + else: + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) + + +class RunInfo(BaseModel): + """Class that contains metadata for a single execution of a Chain or model.""" + + run_id: UUID + """A unique identifier for the model or chain run.""" + + +class ChatResult(BaseModel): + """Class that contains all results for a single chat model call.""" + + generations: List[ChatGeneration] + """List of the chat generations. This is a List because an input can have multiple + candidate generations. + """ + llm_output: Optional[dict] = None + """For arbitrary LLM provider specific output.""" + + +class LLMResult(BaseModel): + """Class that contains all results for a batched LLM call.""" + + generations: List[List[Generation]] + """List of generated outputs. 
This is a List[List[]] because + each input could have multiple candidate generations.""" + llm_output: Optional[dict] = None + """Arbitrary LLM provider-specific output.""" + run: Optional[List[RunInfo]] = None + """List of metadata info for model call for each input.""" + + def flatten(self) -> List[LLMResult]: + """Flatten generations into a single list. + + Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult + contains only a single Generation. If token usage information is available, + it is kept only for the LLMResult corresponding to the top-choice + Generation, to avoid over-counting of token usage downstream. + + Returns: + List of LLMResults where each returned LLMResult contains a single + Generation. + """ + llm_results = [] + for i, gen_list in enumerate(self.generations): + # Avoid double counting tokens in OpenAICallback + if i == 0: + llm_results.append( + LLMResult( + generations=[gen_list], + llm_output=self.llm_output, + ) + ) + else: + if self.llm_output is not None: + llm_output = deepcopy(self.llm_output) + llm_output["token_usage"] = dict() + else: + llm_output = None + llm_results.append( + LLMResult( + generations=[gen_list], + llm_output=llm_output, + ) + ) + return llm_results + + def __eq__(self, other: object) -> bool: + """Check for LLMResult equality by ignoring any metadata related to runs.""" + if not isinstance(other, LLMResult): + return NotImplemented + return ( + self.generations == other.generations + and self.llm_output == other.llm_output + ) diff --git a/libs/core/langchain_core/schema/output_parser.py b/libs/core/langchain_core/schema/output_parser.py new file mode 100644 index 00000000000..5dd2ebfb2ac --- /dev/null +++ b/libs/core/langchain_core/schema/output_parser.py @@ -0,0 +1,475 @@ +from __future__ import annotations + +import asyncio +import functools +from abc import ABC, abstractmethod +from typing import ( + Any, + AsyncIterator, + Dict, + Generic, + Iterator, + List, + Optional, + Type, + TypeVar, + Union, +) + +from typing_extensions import get_args + +from langchain_core.runnables import RunnableConfig, RunnableSerializable +from langchain_core.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk +from langchain_core.schema.output import ( + ChatGeneration, + ChatGenerationChunk, + Generation, + GenerationChunk, +) +from langchain_core.schema.prompt import PromptValue + +T = TypeVar("T") + + +class BaseLLMOutputParser(Generic[T], ABC): + """Abstract base class for parsing the outputs of a model.""" + + @abstractmethod + def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: + """Parse a list of candidate model Generations into a specific format. + + Args: + result: A list of Generations to be parsed. The Generations are assumed + to be different candidate outputs for a single model input. + + Returns: + Structured output. + """ + + async def aparse_result( + self, result: List[Generation], *, partial: bool = False + ) -> T: + """Parse a list of candidate model Generations into a specific format. + + Args: + result: A list of Generations to be parsed. The Generations are assumed + to be different candidate outputs for a single model input. + + Returns: + Structured output. 
+ """ + return await asyncio.get_running_loop().run_in_executor( + None, self.parse_result, result + ) + + +class BaseGenerationOutputParser( + BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T] +): + """Base class to parse the output of an LLM call.""" + + @property + def InputType(self) -> Any: + return Union[str, AnyMessage] + + @property + def OutputType(self) -> Type[T]: + # even though mypy complains this isn't valid, + # it is good enough for pydantic to build the schema from + return T # type: ignore[misc] + + def invoke( + self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None + ) -> T: + if isinstance(input, BaseMessage): + return self._call_with_config( + lambda inner_input: self.parse_result( + [ChatGeneration(message=inner_input)] + ), + input, + config, + run_type="parser", + ) + else: + return self._call_with_config( + lambda inner_input: self.parse_result([Generation(text=inner_input)]), + input, + config, + run_type="parser", + ) + + async def ainvoke( + self, + input: str | BaseMessage, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> T: + if isinstance(input, BaseMessage): + return await self._acall_with_config( + lambda inner_input: self.aparse_result( + [ChatGeneration(message=inner_input)] + ), + input, + config, + run_type="parser", + ) + else: + return await self._acall_with_config( + lambda inner_input: self.aparse_result([Generation(text=inner_input)]), + input, + config, + run_type="parser", + ) + + +class BaseOutputParser( + BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T] +): + """Base class to parse the output of an LLM call. + + Output parsers help structure language model responses. + + Example: + .. code-block:: python + + class BooleanOutputParser(BaseOutputParser[bool]): + true_val: str = "YES" + false_val: str = "NO" + + def parse(self, text: str) -> bool: + cleaned_text = text.strip().upper() + if cleaned_text not in (self.true_val.upper(), self.false_val.upper()): + raise OutputParserException( + f"BooleanOutputParser expected output value to either be " + f"{self.true_val} or {self.false_val} (case-insensitive). " + f"Received {cleaned_text}." + ) + return cleaned_text == self.true_val.upper() + + @property + def _type(self) -> str: + return "boolean_output_parser" + """ # noqa: E501 + + @property + def InputType(self) -> Any: + return Union[str, AnyMessage] + + @property + def OutputType(self) -> Type[T]: + for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] + type_args = get_args(cls) + if type_args and len(type_args) == 1: + return type_args[0] + + raise TypeError( + f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. " + "Override the OutputType property to specify the output type." 
+ ) + + def invoke( + self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None + ) -> T: + if isinstance(input, BaseMessage): + return self._call_with_config( + lambda inner_input: self.parse_result( + [ChatGeneration(message=inner_input)] + ), + input, + config, + run_type="parser", + ) + else: + return self._call_with_config( + lambda inner_input: self.parse_result([Generation(text=inner_input)]), + input, + config, + run_type="parser", + ) + + async def ainvoke( + self, + input: str | BaseMessage, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> T: + if isinstance(input, BaseMessage): + return await self._acall_with_config( + lambda inner_input: self.aparse_result( + [ChatGeneration(message=inner_input)] + ), + input, + config, + run_type="parser", + ) + else: + return await self._acall_with_config( + lambda inner_input: self.aparse_result([Generation(text=inner_input)]), + input, + config, + run_type="parser", + ) + + def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: + """Parse a list of candidate model Generations into a specific format. + + The return value is parsed from only the first Generation in the result, which + is assumed to be the highest-likelihood Generation. + + Args: + result: A list of Generations to be parsed. The Generations are assumed + to be different candidate outputs for a single model input. + + Returns: + Structured output. + """ + return self.parse(result[0].text) + + @abstractmethod + def parse(self, text: str) -> T: + """Parse a single string model output into some structure. + + Args: + text: String output of a language model. + + Returns: + Structured output. + """ + + async def aparse_result( + self, result: List[Generation], *, partial: bool = False + ) -> T: + """Parse a list of candidate model Generations into a specific format. + + The return value is parsed from only the first Generation in the result, which + is assumed to be the highest-likelihood Generation. + + Args: + result: A list of Generations to be parsed. The Generations are assumed + to be different candidate outputs for a single model input. + + Returns: + Structured output. + """ + return await asyncio.get_running_loop().run_in_executor( + None, functools.partial(self.parse_result, partial=partial), result + ) + + async def aparse(self, text: str) -> T: + """Parse a single string model output into some structure. + + Args: + text: String output of a language model. + + Returns: + Structured output. + """ + return await asyncio.get_running_loop().run_in_executor(None, self.parse, text) + + # TODO: rename 'completion' -> 'text'. + def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any: + """Parse the output of an LLM call with the input prompt for context. + + The prompt is largely provided in the event the OutputParser wants + to retry or fix the output in some way, and needs information from + the prompt to do so. + + Args: + completion: String output of a language model. + prompt: Input PromptValue. + + Returns: + Structured output + """ + return self.parse(completion) + + def get_format_instructions(self) -> str: + """Instructions on how the LLM output should be formatted.""" + raise NotImplementedError + + @property + def _type(self) -> str: + """Return the output parser type for serialization.""" + raise NotImplementedError( + f"_type property is not implemented in class {self.__class__.__name__}." + " This is required for serialization." 
+ ) + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of output parser.""" + output_parser_dict = super().dict(**kwargs) + try: + output_parser_dict["_type"] = self._type + except NotImplementedError: + pass + return output_parser_dict + + +class BaseTransformOutputParser(BaseOutputParser[T]): + """Base class for an output parser that can handle streaming input.""" + + def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]: + for chunk in input: + if isinstance(chunk, BaseMessage): + yield self.parse_result([ChatGeneration(message=chunk)]) + else: + yield self.parse_result([Generation(text=chunk)]) + + async def _atransform( + self, input: AsyncIterator[Union[str, BaseMessage]] + ) -> AsyncIterator[T]: + async for chunk in input: + if isinstance(chunk, BaseMessage): + yield self.parse_result([ChatGeneration(message=chunk)]) + else: + yield self.parse_result([Generation(text=chunk)]) + + def transform( + self, + input: Iterator[Union[str, BaseMessage]], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[T]: + yield from self._transform_stream_with_config( + input, self._transform, config, run_type="parser" + ) + + async def atransform( + self, + input: AsyncIterator[Union[str, BaseMessage]], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[T]: + async for chunk in self._atransform_stream_with_config( + input, self._atransform, config, run_type="parser" + ): + yield chunk + + +class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): + """Base class for an output parser that can handle streaming input.""" + + diff: bool = False + """In streaming mode, whether to yield diffs between the previous and current + parsed output, or just the current parsed output. + """ + + def _diff(self, prev: Optional[T], next: T) -> T: + """Convert parsed outputs into a diff format. 
The semantics of this are + up to the output parser.""" + raise NotImplementedError() + + def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: + prev_parsed = None + acc_gen = None + for chunk in input: + if isinstance(chunk, BaseMessageChunk): + chunk_gen: Generation = ChatGenerationChunk(message=chunk) + elif isinstance(chunk, BaseMessage): + chunk_gen = ChatGenerationChunk( + message=BaseMessageChunk(**chunk.dict()) + ) + else: + chunk_gen = GenerationChunk(text=chunk) + + if acc_gen is None: + acc_gen = chunk_gen + else: + acc_gen += chunk_gen + + parsed = self.parse_result([acc_gen], partial=True) + if parsed is not None and parsed != prev_parsed: + if self.diff: + yield self._diff(prev_parsed, parsed) + else: + yield parsed + prev_parsed = parsed + + async def _atransform( + self, input: AsyncIterator[Union[str, BaseMessage]] + ) -> AsyncIterator[T]: + prev_parsed = None + acc_gen = None + async for chunk in input: + if isinstance(chunk, BaseMessageChunk): + chunk_gen: Generation = ChatGenerationChunk(message=chunk) + elif isinstance(chunk, BaseMessage): + chunk_gen = ChatGenerationChunk( + message=BaseMessageChunk(**chunk.dict()) + ) + else: + chunk_gen = GenerationChunk(text=chunk) + + if acc_gen is None: + acc_gen = chunk_gen + else: + acc_gen += chunk_gen + + parsed = self.parse_result([acc_gen], partial=True) + if parsed is not None and parsed != prev_parsed: + if self.diff: + yield self._diff(prev_parsed, parsed) + else: + yield parsed + prev_parsed = parsed + + +class StrOutputParser(BaseTransformOutputParser[str]): + """OutputParser that parses LLMResult into the top likely string.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + + @property + def _type(self) -> str: + """Return the output parser type for serialization.""" + return "default" + + def parse(self, text: str) -> str: + """Returns the input text with no changes.""" + return text + + +# TODO: Deprecate +NoOpOutputParser = StrOutputParser + + +class OutputParserException(ValueError): + """Exception that output parsers should raise to signify a parsing error. + + This exists to differentiate parsing errors from other code or execution errors + that also may arise inside the output parser. OutputParserExceptions will be + available to catch and handle in ways to fix the parsing error, while other + errors will be raised. + + Args: + error: The error that's being re-raised or an error message. + observation: String explanation of error which can be passed to a + model to try and remediate the issue. + llm_output: String model output which is error-ing. + send_to_llm: Whether to send the observation and llm_output back to an Agent + after an OutputParserException has been raised. This gives the underlying + model driving the agent the context that the previous output was improperly + structured, in the hopes that it will update the output to the correct + format. 
+ """ + + def __init__( + self, + error: Any, + observation: Optional[str] = None, + llm_output: Optional[str] = None, + send_to_llm: bool = False, + ): + super(OutputParserException, self).__init__(error) + if send_to_llm: + if observation is None or llm_output is None: + raise ValueError( + "Arguments 'observation' & 'llm_output'" + " are required if 'send_to_llm' is True" + ) + self.observation = observation + self.llm_output = llm_output + self.send_to_llm = send_to_llm diff --git a/libs/core/langchain_core/schema/prompt.py b/libs/core/langchain_core/schema/prompt.py new file mode 100644 index 00000000000..f20cfdf4216 --- /dev/null +++ b/libs/core/langchain_core/schema/prompt.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List + +from langchain_core.load.serializable import Serializable +from langchain_core.schema.messages import BaseMessage + + +class PromptValue(Serializable, ABC): + """Base abstract class for inputs to any language model. + + PromptValues can be converted to both LLM (pure text-generation) inputs and + ChatModel inputs. + """ + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + + @abstractmethod + def to_string(self) -> str: + """Return prompt value as string.""" + + @abstractmethod + def to_messages(self) -> List[BaseMessage]: + """Return prompt as a list of Messages.""" diff --git a/libs/core/langchain_core/schema/prompt_template.py b/libs/core/langchain_core/schema/prompt_template.py new file mode 100644 index 00000000000..2985c5bb06c --- /dev/null +++ b/libs/core/langchain_core/schema/prompt_template.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union + +import yaml + +from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator +from langchain_core.runnables import RunnableConfig, RunnableSerializable +from langchain_core.schema.document import Document +from langchain_core.schema.output_parser import BaseOutputParser +from langchain_core.schema.prompt import PromptValue + + +class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): + """Base class for all prompt templates, returning a prompt.""" + + input_variables: List[str] + """A list of the names of the variables the prompt template expects.""" + input_types: Dict[str, Any] = Field(default_factory=dict) + """A dictionary of the types of the variables the prompt template expects. + If not provided, all variables are assumed to be strings.""" + output_parser: Optional[BaseOutputParser] = None + """How to parse the output of calling an LLM on this formatted prompt.""" + partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field( + default_factory=dict + ) + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @property + def OutputType(self) -> Any: + from langchain_core.prompts.base import StringPromptValue + from langchain_core.prompts.chat import ChatPromptValueConcrete + + return Union[StringPromptValue, ChatPromptValueConcrete] + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + # This is correct, but pydantic typings/mypy don't think so. 
+ return create_model( # type: ignore[call-overload] + "PromptInput", + **{k: (self.input_types.get(k, str), None) for k in self.input_variables}, + ) + + def invoke( + self, input: Dict, config: Optional[RunnableConfig] = None + ) -> PromptValue: + return self._call_with_config( + lambda inner_input: self.format_prompt( + **{key: inner_input[key] for key in self.input_variables} + ), + input, + config, + run_type="prompt", + ) + + @abstractmethod + def format_prompt(self, **kwargs: Any) -> PromptValue: + """Create Chat Messages.""" + + @root_validator() + def validate_variable_names(cls, values: Dict) -> Dict: + """Validate variable names do not include restricted names.""" + if "stop" in values["input_variables"]: + raise ValueError( + "Cannot have an input variable named 'stop', as it is used internally," + " please rename." + ) + if "stop" in values["partial_variables"]: + raise ValueError( + "Cannot have an partial variable named 'stop', as it is used " + "internally, please rename." + ) + + overall = set(values["input_variables"]).intersection( + values["partial_variables"] + ) + if overall: + raise ValueError( + f"Found overlapping input and partial variables: {overall}" + ) + return values + + def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: + """Return a partial of the prompt template.""" + prompt_dict = self.__dict__.copy() + prompt_dict["input_variables"] = list( + set(self.input_variables).difference(kwargs) + ) + prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} + return type(self)(**prompt_dict) + + def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]: + # Get partial params: + partial_kwargs = { + k: v if isinstance(v, str) else v() + for k, v in self.partial_variables.items() + } + return {**partial_kwargs, **kwargs} + + @abstractmethod + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. code-block:: python + + prompt.format(variable1="foo") + """ + + @property + def _prompt_type(self) -> str: + """Return the prompt type key.""" + raise NotImplementedError + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of prompt.""" + prompt_dict = super().dict(**kwargs) + try: + prompt_dict["_type"] = self._prompt_type + except NotImplementedError: + pass + return prompt_dict + + def save(self, file_path: Union[Path, str]) -> None: + """Save the prompt. + + Args: + file_path: Path to directory to save prompt to. + + Example: + .. code-block:: python + + prompt.save(file_path="path/prompt.yaml") + """ + if self.partial_variables: + raise ValueError("Cannot save prompt with partial variables.") + + # Fetch dictionary to save + prompt_dict = self.dict() + if "_type" not in prompt_dict: + raise NotImplementedError(f"Prompt {self} does not support saving.") + + # Convert file to Path object. 
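+        # The resulting Path suffix decides below whether the prompt is
+        # serialized as JSON or YAML.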
+ if isinstance(file_path, str): + save_path = Path(file_path) + else: + save_path = file_path + + directory_path = save_path.parent + directory_path.mkdir(parents=True, exist_ok=True) + + if save_path.suffix == ".json": + with open(file_path, "w") as f: + json.dump(prompt_dict, f, indent=4) + elif save_path.suffix == ".yaml": + with open(file_path, "w") as f: + yaml.dump(prompt_dict, f, default_flow_style=False) + else: + raise ValueError(f"{save_path} must be json or yaml") + + +def format_document(doc: Document, prompt: BasePromptTemplate) -> str: + """Format a document into a string based on a prompt template. + + First, this pulls information from the document from two sources: + + 1. `page_content`: + This takes the information from the `document.page_content` + and assigns it to a variable named `page_content`. + 2. metadata: + This takes information from `document.metadata` and assigns + it to variables of the same name. + + Those variables are then passed into the `prompt` to produce a formatted string. + + Args: + doc: Document, the page_content and metadata will be used to create + the final string. + prompt: BasePromptTemplate, will be used to format the page_content + and metadata into the final string. + + Returns: + string of the document formatted. + + Example: + .. code-block:: python + + from langchain_core.schema import Document + from langchain_core.prompts import PromptTemplate + + doc = Document(page_content="This is a joke", metadata={"page": "1"}) + prompt = PromptTemplate.from_template("Page {page}: {page_content}") + format_document(doc, prompt) + >>> "Page 1: This is a joke" + """ + base_info = {"page_content": doc.page_content, **doc.metadata} + missing_metadata = set(prompt.input_variables).difference(base_info) + if len(missing_metadata) > 0: + required_metadata = [ + iv for iv in prompt.input_variables if iv != "page_content" + ] + raise ValueError( + f"Document prompt requires documents to have metadata variables: " + f"{required_metadata}. Received document with missing metadata: " + f"{list(missing_metadata)}." + ) + document_info = {k: base_info[k] for k in prompt.input_variables} + return prompt.format(**document_info) diff --git a/libs/core/langchain_core/schema/retriever.py b/libs/core/langchain_core/schema/retriever.py new file mode 100644 index 00000000000..1d8d1778317 --- /dev/null +++ b/libs/core/langchain_core/schema/retriever.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import asyncio +import warnings +from abc import ABC, abstractmethod +from functools import partial +from inspect import signature +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from langchain_core.load.dump import dumpd +from langchain_core.runnables import RunnableConfig, RunnableSerializable +from langchain_core.schema.document import Document + +if TYPE_CHECKING: + from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, + Callbacks, + ) + + +class BaseRetriever(RunnableSerializable[str, List[Document]], ABC): + """Abstract base class for a Document retrieval system. + + A retrieval system is defined as something that can take string queries and return + the most 'relevant' Documents from some source. + + Example: + .. 
code-block:: python + + class TFIDFRetriever(BaseRetriever, BaseModel): + vectorizer: Any + docs: List[Document] + tfidf_array: Any + k: int = 4 + + class Config: + arbitrary_types_allowed = True + + def get_relevant_documents(self, query: str) -> List[Document]: + from sklearn.metrics.pairwise import cosine_similarity + + # Ip -- (n_docs,x), Op -- (n_docs,n_Feats) + query_vec = self.vectorizer.transform([query]) + # Op -- (n_docs,1) -- Cosine Sim with each doc + results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,)) + return [self.docs[i] for i in results.argsort()[-self.k :][::-1]] + """ # noqa: E501 + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + _new_arg_supported: bool = False + _expects_other_args: bool = False + tags: Optional[List[str]] = None + """Optional list of tags associated with the retriever. Defaults to None + These tags will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. + You can use these to eg identify a specific instance of a retriever with its + use case. + """ + metadata: Optional[Dict[str, Any]] = None + """Optional metadata associated with the retriever. Defaults to None + This metadata will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. + You can use these to eg identify a specific instance of a retriever with its + use case. + """ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + # Version upgrade for old retrievers that implemented the public + # methods directly. + if cls.get_relevant_documents != BaseRetriever.get_relevant_documents: + warnings.warn( + "Retrievers must implement abstract `_get_relevant_documents` method" + " instead of `get_relevant_documents`", + DeprecationWarning, + ) + swap = cls.get_relevant_documents + cls.get_relevant_documents = ( # type: ignore[assignment] + BaseRetriever.get_relevant_documents + ) + cls._get_relevant_documents = swap # type: ignore[assignment] + if ( + hasattr(cls, "aget_relevant_documents") + and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents + ): + warnings.warn( + "Retrievers must implement abstract `_aget_relevant_documents` method" + " instead of `aget_relevant_documents`", + DeprecationWarning, + ) + aswap = cls.aget_relevant_documents + cls.aget_relevant_documents = ( # type: ignore[assignment] + BaseRetriever.aget_relevant_documents + ) + cls._aget_relevant_documents = aswap # type: ignore[assignment] + parameters = signature(cls._get_relevant_documents).parameters + cls._new_arg_supported = parameters.get("run_manager") is not None + # If a V1 retriever broke the interface and expects additional arguments + cls._expects_other_args = ( + len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0 + ) + + def invoke( + self, input: str, config: Optional[RunnableConfig] = None + ) -> List[Document]: + config = config or {} + return self.get_relevant_documents( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + ) + + async def ainvoke( + self, + input: str, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> List[Document]: + config = config or {} + return await self.aget_relevant_documents( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + 
run_name=config.get("run_name"), + ) + + @abstractmethod + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + """Get documents relevant to a query. + Args: + query: String to find relevant documents for + run_manager: The callbacks handler to use + Returns: + List of relevant documents + """ + + async def _aget_relevant_documents( + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + ) -> List[Document]: + """Asynchronously get documents relevant to a query. + Args: + query: String to find relevant documents for + run_manager: The callbacks handler to use + Returns: + List of relevant documents + """ + return await asyncio.get_running_loop().run_in_executor( + None, partial(self._get_relevant_documents, run_manager=run_manager), query + ) + + def get_relevant_documents( + self, + query: str, + *, + callbacks: Callbacks = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Retrieve documents relevant to a query. + Args: + query: string to find relevant documents for + callbacks: Callback manager or list of callbacks + tags: Optional list of tags associated with the retriever. Defaults to None + These tags will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. + metadata: Optional metadata associated with the retriever. Defaults to None + This metadata will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. + Returns: + List of relevant documents + """ + from langchain_core.callbacks.manager import CallbackManager + + callback_manager = CallbackManager.configure( + callbacks, + None, + verbose=kwargs.get("verbose", False), + inheritable_tags=tags, + local_tags=self.tags, + inheritable_metadata=metadata, + local_metadata=self.metadata, + ) + run_manager = callback_manager.on_retriever_start( + dumpd(self), + query, + name=run_name, + **kwargs, + ) + try: + _kwargs = kwargs if self._expects_other_args else {} + if self._new_arg_supported: + result = self._get_relevant_documents( + query, run_manager=run_manager, **_kwargs + ) + else: + result = self._get_relevant_documents(query, **_kwargs) + except Exception as e: + run_manager.on_retriever_error(e) + raise e + else: + run_manager.on_retriever_end( + result, + **kwargs, + ) + return result + + async def aget_relevant_documents( + self, + query: str, + *, + callbacks: Callbacks = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Asynchronously get documents relevant to a query. + Args: + query: string to find relevant documents for + callbacks: Callback manager or list of callbacks + tags: Optional list of tags associated with the retriever. Defaults to None + These tags will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. + metadata: Optional metadata associated with the retriever. Defaults to None + This metadata will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. 
+ Returns: + List of relevant documents + """ + from langchain_core.callbacks.manager import AsyncCallbackManager + + callback_manager = AsyncCallbackManager.configure( + callbacks, + None, + verbose=kwargs.get("verbose", False), + inheritable_tags=tags, + local_tags=self.tags, + inheritable_metadata=metadata, + local_metadata=self.metadata, + ) + run_manager = await callback_manager.on_retriever_start( + dumpd(self), + query, + name=run_name, + **kwargs, + ) + try: + _kwargs = kwargs if self._expects_other_args else {} + if self._new_arg_supported: + result = await self._aget_relevant_documents( + query, run_manager=run_manager, **_kwargs + ) + else: + result = await self._aget_relevant_documents(query, **_kwargs) + except Exception as e: + await run_manager.on_retriever_error(e) + raise e + else: + await run_manager.on_retriever_end( + result, + **kwargs, + ) + return result diff --git a/libs/core/langchain_core/schema/storage.py b/libs/core/langchain_core/schema/storage.py new file mode 100644 index 00000000000..bae5adc2b8e --- /dev/null +++ b/libs/core/langchain_core/schema/storage.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod +from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union + +K = TypeVar("K") +V = TypeVar("V") + + +class BaseStore(Generic[K, V], ABC): + """Abstract interface for a key-value store.""" + + @abstractmethod + def mget(self, keys: Sequence[K]) -> List[Optional[V]]: + """Get the values associated with the given keys. + + Args: + keys (Sequence[K]): A sequence of keys. + + Returns: + A sequence of optional values associated with the keys. + If a key is not found, the corresponding value will be None. + """ + + @abstractmethod + def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None: + """Set the values for the given keys. + + Args: + key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs. + """ + + @abstractmethod + def mdelete(self, keys: Sequence[K]) -> None: + """Delete the given keys and their associated values. + + Args: + keys (Sequence[K]): A sequence of keys to delete. + """ + + @abstractmethod + def yield_keys( + self, *, prefix: Optional[str] = None + ) -> Union[Iterator[K], Iterator[str]]: + """Get an iterator over keys that match the given prefix. + + Args: + prefix (str): The prefix to match. + + Returns: + Iterator[K | str]: An iterator over keys that match the given prefix. + + This method is allowed to return an iterator over either K or str + depending on what makes more sense for the given store. 
+ """ diff --git a/libs/core/langchain_core/schema/vectorstore.py b/libs/core/langchain_core/schema/vectorstore.py new file mode 100644 index 00000000000..078a05d739c --- /dev/null +++ b/libs/core/langchain_core/schema/vectorstore.py @@ -0,0 +1,702 @@ +from __future__ import annotations + +import asyncio +import logging +import math +import warnings +from abc import ABC, abstractmethod +from functools import partial +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, +) + +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import BaseRetriever +from langchain_core.schema.document import Document +from langchain_core.schema.embeddings import Embeddings + +if TYPE_CHECKING: + from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, + ) + +logger = logging.getLogger(__name__) + +VST = TypeVar("VST", bound="VectorStore") + + +class VectorStore(ABC): + """Interface for vector store.""" + + @abstractmethod + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. + kwargs: vectorstore specific parameters + + Returns: + List of ids from adding the texts into the vectorstore. + """ + + @property + def embeddings(self) -> Optional[Embeddings]: + """Access the query embedding object if available.""" + logger.debug( + f"{Embeddings.__name__} is not implemented for {self.__class__.__name__}" + ) + return None + + def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: + """Delete by vector ID or other criteria. + + Args: + ids: List of ids to delete. + **kwargs: Other keyword arguments that subclasses might use. + + Returns: + Optional[bool]: True if deletion is successful, + False otherwise, None if not implemented. + """ + + raise NotImplementedError("delete method must be implemented by subclass.") + + async def adelete( + self, ids: Optional[List[str]] = None, **kwargs: Any + ) -> Optional[bool]: + """Delete by vector ID or other criteria. + + Args: + ids: List of ids to delete. + **kwargs: Other keyword arguments that subclasses might use. + + Returns: + Optional[bool]: True if deletion is successful, + False otherwise, None if not implemented. + """ + + raise NotImplementedError("delete method must be implemented by subclass.") + + async def aadd_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore.""" + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.add_texts, **kwargs), texts, metadatas + ) + + def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: + """Run more documents through the embeddings and add to the vectorstore. + + Args: + documents (List[Document]: Documents to add to the vectorstore. + + Returns: + List[str]: List of IDs of the added texts. 
+ """ + # TODO: Handle the case where the user doesn't provide ids on the Collection + texts = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + return self.add_texts(texts, metadatas, **kwargs) + + async def aadd_documents( + self, documents: List[Document], **kwargs: Any + ) -> List[str]: + """Run more documents through the embeddings and add to the vectorstore. + + Args: + documents (List[Document]: Documents to add to the vectorstore. + + Returns: + List[str]: List of IDs of the added texts. + """ + texts = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + return await self.aadd_texts(texts, metadatas, **kwargs) + + def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]: + """Return docs most similar to query using specified search type.""" + if search_type == "similarity": + return self.similarity_search(query, **kwargs) + elif search_type == "mmr": + return self.max_marginal_relevance_search(query, **kwargs) + else: + raise ValueError( + f"search_type of {search_type} not allowed. Expected " + "search_type to be 'similarity' or 'mmr'." + ) + + async def asearch( + self, query: str, search_type: str, **kwargs: Any + ) -> List[Document]: + """Return docs most similar to query using specified search type.""" + if search_type == "similarity": + return await self.asimilarity_search(query, **kwargs) + elif search_type == "mmr": + return await self.amax_marginal_relevance_search(query, **kwargs) + else: + raise ValueError( + f"search_type of {search_type} not allowed. Expected " + "search_type to be 'similarity' or 'mmr'." + ) + + @abstractmethod + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + """Return docs most similar to query.""" + + @staticmethod + def _euclidean_relevance_score_fn(distance: float) -> float: + """Return a similarity score on a scale [0, 1].""" + # The 'correct' relevance function + # may differ depending on a few things, including: + # - the distance / similarity metric used by the VectorStore + # - the scale of your embeddings (OpenAI's are unit normed. Many + # others are not!) + # - embedding dimensionality + # - etc. + # This function converts the euclidean norm of normalized embeddings + # (0 is most similar, sqrt(2) most dissimilar) + # to a similarity function (0 to 1) + return 1.0 - distance / math.sqrt(2) + + @staticmethod + def _cosine_relevance_score_fn(distance: float) -> float: + """Normalize the distance to a score on a scale [0, 1].""" + + return 1.0 - distance + + @staticmethod + def _max_inner_product_relevance_score_fn(distance: float) -> float: + """Normalize the distance to a score on a scale [0, 1].""" + if distance > 0: + return 1.0 - distance + + return -1.0 * distance + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """ + The 'correct' relevance function + may differ depending on a few things, including: + - the distance / similarity metric used by the VectorStore + - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) + - embedding dimensionality + - etc. + + Vectorstores should define their own selection based method of relevance. 
+ """ + raise NotImplementedError + + def similarity_search_with_score( + self, *args: Any, **kwargs: Any + ) -> List[Tuple[Document, float]]: + """Run similarity search with distance.""" + raise NotImplementedError + + async def asimilarity_search_with_score( + self, *args: Any, **kwargs: Any + ) -> List[Tuple[Document, float]]: + """Run similarity search with distance asynchronously.""" + + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + func = partial(self.similarity_search_with_score, *args, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def _similarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """ + Default similarity search with relevance scores. Modify if necessary + in subclass. + Return docs and relevance scores in the range [0, 1]. + + 0 is dissimilar, 1 is most similar. + + Args: + query: input text + k: Number of Documents to return. Defaults to 4. + **kwargs: kwargs to be passed to similarity search. Should include: + score_threshold: Optional, a floating point value between 0 to 1 to + filter the resulting set of retrieved docs + + Returns: + List of Tuples of (doc, similarity_score) + """ + relevance_score_fn = self._select_relevance_score_fn() + docs_and_scores = self.similarity_search_with_score(query, k, **kwargs) + return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores] + + async def _asimilarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """ + Default async similarity search with relevance scores. Modify if necessary + in subclass. + Return docs and relevance scores in the range [0, 1]. + + 0 is dissimilar, 1 is most similar. + + Args: + query: input text + k: Number of Documents to return. Defaults to 4. + **kwargs: kwargs to be passed to similarity search. Should include: + score_threshold: Optional, a floating point value between 0 to 1 to + filter the resulting set of retrieved docs + + Returns: + List of Tuples of (doc, similarity_score) + """ + relevance_score_fn = self._select_relevance_score_fn() + docs_and_scores = await self.asimilarity_search_with_score(query, k, **kwargs) + return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores] + + def similarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs and relevance scores in the range [0, 1]. + + 0 is dissimilar, 1 is most similar. + + Args: + query: input text + k: Number of Documents to return. Defaults to 4. + **kwargs: kwargs to be passed to similarity search. 
Should include: + score_threshold: Optional, a floating point value between 0 to 1 to + filter the resulting set of retrieved docs + + Returns: + List of Tuples of (doc, similarity_score) + """ + score_threshold = kwargs.pop("score_threshold", None) + + docs_and_similarities = self._similarity_search_with_relevance_scores( + query, k=k, **kwargs + ) + if any( + similarity < 0.0 or similarity > 1.0 + for _, similarity in docs_and_similarities + ): + warnings.warn( + "Relevance scores must be between" + f" 0 and 1, got {docs_and_similarities}" + ) + + if score_threshold is not None: + docs_and_similarities = [ + (doc, similarity) + for doc, similarity in docs_and_similarities + if similarity >= score_threshold + ] + if len(docs_and_similarities) == 0: + warnings.warn( + "No relevant docs were retrieved using the relevance score" + f" threshold {score_threshold}" + ) + return docs_and_similarities + + async def asimilarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs and relevance scores in the range [0, 1], asynchronously. + + 0 is dissimilar, 1 is most similar. + + Args: + query: input text + k: Number of Documents to return. Defaults to 4. + **kwargs: kwargs to be passed to similarity search. Should include: + score_threshold: Optional, a floating point value between 0 to 1 to + filter the resulting set of retrieved docs + + Returns: + List of Tuples of (doc, similarity_score) + """ + score_threshold = kwargs.pop("score_threshold", None) + + docs_and_similarities = await self._asimilarity_search_with_relevance_scores( + query, k=k, **kwargs + ) + if any( + similarity < 0.0 or similarity > 1.0 + for _, similarity in docs_and_similarities + ): + warnings.warn( + "Relevance scores must be between" + f" 0 and 1, got {docs_and_similarities}" + ) + + if score_threshold is not None: + docs_and_similarities = [ + (doc, similarity) + for doc, similarity in docs_and_similarities + if similarity >= score_threshold + ] + if len(docs_and_similarities) == 0: + warnings.warn( + "No relevant docs were retrieved using the relevance score" + f" threshold {score_threshold}" + ) + return docs_and_similarities + + async def asimilarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + """Return docs most similar to query.""" + + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + func = partial(self.similarity_search, query, k=k, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def similarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the query vector. + """ + raise NotImplementedError + + async def asimilarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + """Return docs most similar to embedding vector.""" + + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. 
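+        # Bind all arguments up front so the synchronous search can be handed
+        # to the default executor unchanged.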
+ func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents selected by maximal marginal relevance. + """ + raise NotImplementedError + + async def amax_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance.""" + + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + func = partial( + self.max_marginal_relevance_search, + query, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + **kwargs, + ) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents selected by maximal marginal relevance. 
+ """ + raise NotImplementedError + + async def amax_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance.""" + raise NotImplementedError + + @classmethod + def from_documents( + cls: Type[VST], + documents: List[Document], + embedding: Embeddings, + **kwargs: Any, + ) -> VST: + """Return VectorStore initialized from documents and embeddings.""" + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs) + + @classmethod + async def afrom_documents( + cls: Type[VST], + documents: List[Document], + embedding: Embeddings, + **kwargs: Any, + ) -> VST: + """Return VectorStore initialized from documents and embeddings.""" + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + return await cls.afrom_texts(texts, embedding, metadatas=metadatas, **kwargs) + + @classmethod + @abstractmethod + def from_texts( + cls: Type[VST], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> VST: + """Return VectorStore initialized from texts and embeddings.""" + + @classmethod + async def afrom_texts( + cls: Type[VST], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> VST: + """Return VectorStore initialized from texts and embeddings.""" + return await asyncio.get_running_loop().run_in_executor( + None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas + ) + + def _get_retriever_tags(self) -> List[str]: + """Get tags for retriever.""" + tags = [self.__class__.__name__] + if self.embeddings: + tags.append(self.embeddings.__class__.__name__) + return tags + + def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever: + """Return VectorStoreRetriever initialized from this VectorStore. + + Args: + search_type (Optional[str]): Defines the type of search that + the Retriever should perform. + Can be "similarity" (default), "mmr", or + "similarity_score_threshold". + search_kwargs (Optional[Dict]): Keyword arguments to pass to the + search function. Can include things like: + k: Amount of documents to return (Default: 4) + score_threshold: Minimum relevance threshold + for similarity_score_threshold + fetch_k: Amount of documents to pass to MMR algorithm (Default: 20) + lambda_mult: Diversity of results returned by MMR; + 1 for minimum diversity and 0 for maximum. (Default: 0.5) + filter: Filter by document metadata + + Returns: + VectorStoreRetriever: Retriever class for VectorStore. + + Examples: + + .. 
code-block:: python + + # Retrieve more documents with higher diversity + # Useful if your dataset has many similar documents + docsearch.as_retriever( + search_type="mmr", + search_kwargs={'k': 6, 'lambda_mult': 0.25} + ) + + # Fetch more documents for the MMR algorithm to consider + # But only return the top 5 + docsearch.as_retriever( + search_type="mmr", + search_kwargs={'k': 5, 'fetch_k': 50} + ) + + # Only retrieve documents that have a relevance score + # Above a certain threshold + docsearch.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={'score_threshold': 0.8} + ) + + # Only get the single most similar document from the dataset + docsearch.as_retriever(search_kwargs={'k': 1}) + + # Use a filter to only retrieve documents from a specific paper + docsearch.as_retriever( + search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}} + ) + """ + tags = kwargs.pop("tags", None) or [] + tags.extend(self._get_retriever_tags()) + return VectorStoreRetriever(vectorstore=self, **kwargs, tags=tags) + + +class VectorStoreRetriever(BaseRetriever): + """Base Retriever class for VectorStore.""" + + vectorstore: VectorStore + """VectorStore to use for retrieval.""" + search_type: str = "similarity" + """Type of search to perform. Defaults to "similarity".""" + search_kwargs: dict = Field(default_factory=dict) + """Keyword arguments to pass to the search function.""" + allowed_search_types: ClassVar[Collection[str]] = ( + "similarity", + "similarity_score_threshold", + "mmr", + ) + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @root_validator() + def validate_search_type(cls, values: Dict) -> Dict: + """Validate search type.""" + search_type = values["search_type"] + if search_type not in cls.allowed_search_types: + raise ValueError( + f"search_type of {search_type} not allowed. Valid values are: " + f"{cls.allowed_search_types}" + ) + if search_type == "similarity_score_threshold": + score_threshold = values["search_kwargs"].get("score_threshold") + if (score_threshold is None) or (not isinstance(score_threshold, float)): + raise ValueError( + "`score_threshold` is not specified with a float value(0~1) " + "in `search_kwargs`." 
+ ) + return values + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + if self.search_type == "similarity": + docs = self.vectorstore.similarity_search(query, **self.search_kwargs) + elif self.search_type == "similarity_score_threshold": + docs_and_similarities = ( + self.vectorstore.similarity_search_with_relevance_scores( + query, **self.search_kwargs + ) + ) + docs = [doc for doc, _ in docs_and_similarities] + elif self.search_type == "mmr": + docs = self.vectorstore.max_marginal_relevance_search( + query, **self.search_kwargs + ) + else: + raise ValueError(f"search_type of {self.search_type} not allowed.") + return docs + + async def _aget_relevant_documents( + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + ) -> List[Document]: + if self.search_type == "similarity": + docs = await self.vectorstore.asimilarity_search( + query, **self.search_kwargs + ) + elif self.search_type == "similarity_score_threshold": + docs_and_similarities = ( + await self.vectorstore.asimilarity_search_with_relevance_scores( + query, **self.search_kwargs + ) + ) + docs = [doc for doc, _ in docs_and_similarities] + elif self.search_type == "mmr": + docs = await self.vectorstore.amax_marginal_relevance_search( + query, **self.search_kwargs + ) + else: + raise ValueError(f"search_type of {self.search_type} not allowed.") + return docs + + def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: + """Add documents to vectorstore.""" + return self.vectorstore.add_documents(documents, **kwargs) + + async def aadd_documents( + self, documents: List[Document], **kwargs: Any + ) -> List[str]: + """Add documents to vectorstore.""" + return await self.vectorstore.aadd_documents(documents, **kwargs) diff --git a/libs/core/langchain_core/tool.py b/libs/core/langchain_core/tool.py new file mode 100644 index 00000000000..9f362bfa8e3 --- /dev/null +++ b/libs/core/langchain_core/tool.py @@ -0,0 +1,845 @@ +"""Base implementation for tools or skills.""" +from __future__ import annotations + +import asyncio +import inspect +import warnings +from abc import abstractmethod +from functools import partial +from inspect import signature +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union + +from langchain_core.callbacks.base import BaseCallbackManager +from langchain_core.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForToolRun, + CallbackManager, + CallbackManagerForToolRun, + Callbacks, +) +from langchain_core.load.serializable import Serializable +from langchain_core.pydantic_v1 import ( + BaseModel, + Extra, + Field, + create_model, + root_validator, + validate_arguments, +) +from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable + + +class SchemaAnnotationError(TypeError): + """Raised when 'args_schema' is missing or has an incorrect type annotation.""" + + +def _create_subset_model( + name: str, model: BaseModel, field_names: list +) -> Type[BaseModel]: + """Create a pydantic model with only a subset of model's fields.""" + fields = {} + for field_name in field_names: + field = model.__fields__[field_name] + fields[field_name] = (field.outer_type_, field.field_info) + return create_model(name, **fields) # type: ignore + + +def _get_filtered_args( + inferred_model: Type[BaseModel], + func: Callable, +) -> dict: + """Get the arguments from a function's signature.""" + schema = inferred_model.schema()["properties"] + 
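+    # Keep only the schema properties that correspond to real parameters of the
+    # wrapped function, dropping the injected run_manager/callbacks arguments.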
+    valid_keys = signature(func).parameters
+    return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
+
+
+class _SchemaConfig:
+    """Configuration for the pydantic model."""
+
+    extra: Any = Extra.forbid
+    arbitrary_types_allowed: bool = True
+
+
+def create_schema_from_function(
+    model_name: str,
+    func: Callable,
+) -> Type[BaseModel]:
+    """Create a pydantic schema from a function's signature.
+
+    Args:
+        model_name: Name to assign to the generated pydantic schema.
+        func: Function to generate the schema from.
+
+    Returns:
+        A pydantic model with the same arguments as the function.
+    """
+    # https://docs.pydantic.dev/latest/usage/validation_decorator/
+    validated = validate_arguments(func, config=_SchemaConfig)  # type: ignore
+    inferred_model = validated.model  # type: ignore
+    if "run_manager" in inferred_model.__fields__:
+        del inferred_model.__fields__["run_manager"]
+    if "callbacks" in inferred_model.__fields__:
+        del inferred_model.__fields__["callbacks"]
+    # Pydantic adds placeholder virtual fields we need to strip
+    valid_properties = _get_filtered_args(inferred_model, func)
+    return _create_subset_model(
+        f"{model_name}Schema", inferred_model, list(valid_properties)
+    )
+
+
+class ToolException(Exception):
+    """An optional exception that a tool can raise when an execution error occurs.
+
+    When this exception is thrown, the agent will not stop working,
+    but will handle the exception according to the handle_tool_error
+    variable of the tool, and the processing result will be returned
+    to the agent as an observation, and printed in red on the console.
+    """
+
+    pass
+
+
+class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
+    """Interface LangChain tools must implement."""
+
+    def __init_subclass__(cls, **kwargs: Any) -> None:
+        """Create the definition of the new tool class."""
+        super().__init_subclass__(**kwargs)
+
+        args_schema_type = cls.__annotations__.get("args_schema", None)
+
+        if args_schema_type is not None:
+            if args_schema_type is None or args_schema_type == BaseModel:
+                # Throw errors for common mis-annotations.
+                # TODO: Use get_args / get_origin and fully
+                # specify valid annotations.
+                typehint_mandate = """
+class ChildTool(BaseTool):
+    ...
+    args_schema: Type[BaseModel] = SchemaClass
+    ..."""
+                name = cls.__name__
+                raise SchemaAnnotationError(
+                    f"Tool definition for {name} must include valid type annotations"
+                    f" for argument 'args_schema' to behave as expected.\n"
+                    f"Expected annotation of 'Type[BaseModel]'"
+                    f" but got '{args_schema_type}'.\n"
+                    f"Expected class looks like:\n"
+                    f"{typehint_mandate}"
+                )
+
+    name: str
+    """The unique name of the tool that clearly communicates its purpose."""
+    description: str
+    """Used to tell the model how/when/why to use the tool.
+
+    You can provide few-shot examples as a part of the description.
+    """
+    args_schema: Optional[Type[BaseModel]] = None
+    """Pydantic model class to validate and parse the tool's input arguments."""
+    return_direct: bool = False
+    """Whether to return the tool's output directly. Setting this to True means
+    that after the tool is called, the AgentExecutor will stop looping.
+    """
+    verbose: bool = False
+    """Whether to log the tool's progress."""
+
+    callbacks: Callbacks = Field(default=None, exclude=True)
+    """Callbacks to be called during tool execution."""
+    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
+    """Deprecated.
Please use callbacks instead.""" + tags: Optional[List[str]] = None + """Optional list of tags associated with the tool. Defaults to None + These tags will be associated with each call to this tool, + and passed as arguments to the handlers defined in `callbacks`. + You can use these to eg identify a specific instance of a tool with its use case. + """ + metadata: Optional[Dict[str, Any]] = None + """Optional metadata associated with the tool. Defaults to None + This metadata will be associated with each call to this tool, + and passed as arguments to the handlers defined in `callbacks`. + You can use these to eg identify a specific instance of a tool with its use case. + """ + + handle_tool_error: Optional[ + Union[bool, str, Callable[[ToolException], str]] + ] = False + """Handle the content of the ToolException thrown.""" + + class Config(Serializable.Config): + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @property + def is_single_input(self) -> bool: + """Whether the tool only accepts a single input.""" + keys = {k for k in self.args if k != "kwargs"} + return len(keys) == 1 + + @property + def args(self) -> dict: + if self.args_schema is not None: + return self.args_schema.schema()["properties"] + else: + schema = create_schema_from_function(self.name, self._run) + return schema.schema()["properties"] + + # --- Runnable --- + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + """The tool's input schema.""" + if self.args_schema is not None: + return self.args_schema + else: + return create_schema_from_function(self.name, self._run) + + def invoke( + self, + input: Union[str, Dict], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Any: + config = config or {} + return self.run( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + **kwargs, + ) + + async def ainvoke( + self, + input: Union[str, Dict], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Any: + config = config or {} + return await self.arun( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + **kwargs, + ) + + # --- Tool --- + + def _parse_input( + self, + tool_input: Union[str, Dict], + ) -> Union[str, Dict[str, Any]]: + """Convert tool input to pydantic model.""" + input_args = self.args_schema + if isinstance(tool_input, str): + if input_args is not None: + key_ = next(iter(input_args.__fields__.keys())) + input_args.validate({key_: tool_input}) + return tool_input + else: + if input_args is not None: + result = input_args.parse_obj(tool_input) + return {k: v for k, v in result.dict().items() if k in tool_input} + return tool_input + + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values + + @abstractmethod + def _run( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Use the tool. + + Add run_manager: Optional[CallbackManagerForToolRun] = None + to child implementations to enable tracing, + """ + + async def _arun( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Use the tool asynchronously. 
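+
+        The default implementation runs the synchronous ``_run`` in an
+        executor; override this method to provide a native async implementation.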
+ + Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None + to child implementations to enable tracing, + """ + return await asyncio.get_running_loop().run_in_executor( + None, + partial(self._run, **kwargs), + *args, + ) + + def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + # For backwards compatibility, if run_input is a string, + # pass as a positional argument. + if isinstance(tool_input, str): + return (tool_input,), {} + else: + return (), tool_input + + def run( + self, + tool_input: Union[str, Dict], + verbose: Optional[bool] = None, + start_color: Optional[str] = "green", + color: Optional[str] = "green", + callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Run the tool.""" + parsed_input = self._parse_input(tool_input) + if not self.verbose and verbose is not None: + verbose_ = verbose + else: + verbose_ = self.verbose + callback_manager = CallbackManager.configure( + callbacks, + self.callbacks, + verbose_, + tags, + self.tags, + metadata, + self.metadata, + ) + # TODO: maybe also pass through run_manager is _run supports kwargs + new_arg_supported = signature(self._run).parameters.get("run_manager") + run_manager = callback_manager.on_tool_start( + {"name": self.name, "description": self.description}, + tool_input if isinstance(tool_input, str) else str(tool_input), + color=start_color, + name=run_name, + **kwargs, + ) + try: + tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) + observation = ( + self._run(*tool_args, run_manager=run_manager, **tool_kwargs) + if new_arg_supported + else self._run(*tool_args, **tool_kwargs) + ) + except ToolException as e: + if not self.handle_tool_error: + run_manager.on_tool_error(e) + raise e + elif isinstance(self.handle_tool_error, bool): + if e.args: + observation = e.args[0] + else: + observation = "Tool execution error" + elif isinstance(self.handle_tool_error, str): + observation = self.handle_tool_error + elif callable(self.handle_tool_error): + observation = self.handle_tool_error(e) + else: + raise ValueError( + f"Got unexpected type of `handle_tool_error`. Expected bool, str " + f"or callable. 
Received: {self.handle_tool_error}" + ) + run_manager.on_tool_end( + str(observation), color="red", name=self.name, **kwargs + ) + return observation + except (Exception, KeyboardInterrupt) as e: + run_manager.on_tool_error(e) + raise e + else: + run_manager.on_tool_end( + str(observation), color=color, name=self.name, **kwargs + ) + return observation + + async def arun( + self, + tool_input: Union[str, Dict], + verbose: Optional[bool] = None, + start_color: Optional[str] = "green", + color: Optional[str] = "green", + callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Run the tool asynchronously.""" + parsed_input = self._parse_input(tool_input) + if not self.verbose and verbose is not None: + verbose_ = verbose + else: + verbose_ = self.verbose + callback_manager = AsyncCallbackManager.configure( + callbacks, + self.callbacks, + verbose_, + tags, + self.tags, + metadata, + self.metadata, + ) + new_arg_supported = signature(self._arun).parameters.get("run_manager") + run_manager = await callback_manager.on_tool_start( + {"name": self.name, "description": self.description}, + tool_input if isinstance(tool_input, str) else str(tool_input), + color=start_color, + name=run_name, + **kwargs, + ) + try: + # We then call the tool on the tool input to get an observation + tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) + observation = ( + await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs) + if new_arg_supported + else await self._arun(*tool_args, **tool_kwargs) + ) + except ToolException as e: + if not self.handle_tool_error: + await run_manager.on_tool_error(e) + raise e + elif isinstance(self.handle_tool_error, bool): + if e.args: + observation = e.args[0] + else: + observation = "Tool execution error" + elif isinstance(self.handle_tool_error, str): + observation = self.handle_tool_error + elif callable(self.handle_tool_error): + observation = self.handle_tool_error(e) + else: + raise ValueError( + f"Got unexpected type of `handle_tool_error`. Expected bool, str " + f"or callable. 
Received: {self.handle_tool_error}" + ) + await run_manager.on_tool_end( + str(observation), color="red", name=self.name, **kwargs + ) + return observation + except (Exception, KeyboardInterrupt) as e: + await run_manager.on_tool_error(e) + raise e + else: + await run_manager.on_tool_end( + str(observation), color=color, name=self.name, **kwargs + ) + return observation + + def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: + """Make tool callable.""" + return self.run(tool_input, callbacks=callbacks) + + +class Tool(BaseTool): + """Tool that takes in function or coroutine directly.""" + + description: str = "" + func: Optional[Callable[..., str]] + """The function to run when the tool is called.""" + coroutine: Optional[Callable[..., Awaitable[str]]] = None + """The asynchronous version of the function.""" + + # --- Runnable --- + + async def ainvoke( + self, + input: Union[str, Dict], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Any: + if not self.coroutine: + # If the tool does not implement async, fall back to default implementation + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.invoke, input, config, **kwargs) + ) + + return await super().ainvoke(input, config, **kwargs) + + # --- Tool --- + + @property + def args(self) -> dict: + """The tool's input arguments.""" + if self.args_schema is not None: + return self.args_schema.schema()["properties"] + # For backwards compatibility, if the function signature is ambiguous, + # assume it takes a single string input. + return {"tool_input": {"type": "string"}} + + def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + """Convert tool input to pydantic model.""" + args, kwargs = super()._to_args_and_kwargs(tool_input) + # For backwards compatibility. The tool must be run with a single input + all_args = list(args) + list(kwargs.values()) + if len(all_args) != 1: + raise ToolException( + f"Too many arguments to single-input tool {self.name}." 
+ f" Args: {all_args}" + ) + return tuple(all_args), {} + + def _run( + self, + *args: Any, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: + """Use the tool.""" + if self.func: + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else self.func(*args, **kwargs) + ) + raise NotImplementedError("Tool does not support sync") + + async def _arun( + self, + *args: Any, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: + """Use the tool asynchronously.""" + if self.coroutine: + new_argument_supported = signature(self.coroutine).parameters.get( + "callbacks" + ) + return ( + await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else await self.coroutine(*args, **kwargs) + ) + else: + return await asyncio.get_running_loop().run_in_executor( + None, partial(self._run, run_manager=run_manager, **kwargs), *args + ) + + # TODO: this is for backwards compatibility, remove in future + def __init__( + self, name: str, func: Optional[Callable], description: str, **kwargs: Any + ) -> None: + """Initialize tool.""" + super(Tool, self).__init__( + name=name, func=func, description=description, **kwargs + ) + + @classmethod + def from_function( + cls, + func: Optional[Callable], + name: str, # We keep these required to support backwards compatibility + description: str, + return_direct: bool = False, + args_schema: Optional[Type[BaseModel]] = None, + coroutine: Optional[ + Callable[..., Awaitable[Any]] + ] = None, # This is last for compatibility, but should be after func + **kwargs: Any, + ) -> Tool: + """Initialize tool from a function.""" + if func is None and coroutine is None: + raise ValueError("Function and/or coroutine must be provided") + return cls( + name=name, + func=func, + coroutine=coroutine, + description=description, + return_direct=return_direct, + args_schema=args_schema, + **kwargs, + ) + + +class StructuredTool(BaseTool): + """Tool that can operate on any number of inputs.""" + + description: str = "" + args_schema: Type[BaseModel] = Field(..., description="The tool schema.") + """The input arguments' schema.""" + func: Optional[Callable[..., Any]] + """The function to run when the tool is called.""" + coroutine: Optional[Callable[..., Awaitable[Any]]] = None + """The asynchronous version of the function.""" + + # --- Runnable --- + + async def ainvoke( + self, + input: Union[str, Dict], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Any: + if not self.coroutine: + # If the tool does not implement async, fall back to default implementation + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.invoke, input, config, **kwargs) + ) + + return await super().ainvoke(input, config, **kwargs) + + # --- Tool --- + + @property + def args(self) -> dict: + """The tool's input arguments.""" + return self.args_schema.schema()["properties"] + + def _run( + self, + *args: Any, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: + """Use the tool.""" + if self.func: + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else 
self.func(*args, **kwargs) + ) + raise NotImplementedError("Tool does not support sync") + + async def _arun( + self, + *args: Any, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: + """Use the tool asynchronously.""" + if self.coroutine: + new_argument_supported = signature(self.coroutine).parameters.get( + "callbacks" + ) + return ( + await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else await self.coroutine(*args, **kwargs) + ) + return await asyncio.get_running_loop().run_in_executor( + None, + partial(self._run, run_manager=run_manager, **kwargs), + *args, + ) + + @classmethod + def from_function( + cls, + func: Optional[Callable] = None, + coroutine: Optional[Callable[..., Awaitable[Any]]] = None, + name: Optional[str] = None, + description: Optional[str] = None, + return_direct: bool = False, + args_schema: Optional[Type[BaseModel]] = None, + infer_schema: bool = True, + **kwargs: Any, + ) -> StructuredTool: + """Create tool from a given function. + + A classmethod that helps to create a tool from a function. + + Args: + func: The function from which to create a tool + coroutine: The async function from which to create a tool + name: The name of the tool. Defaults to the function name + description: The description of the tool. Defaults to the function docstring + return_direct: Whether to return the result directly or as a callback + args_schema: The schema of the tool's input arguments + infer_schema: Whether to infer the schema from the function's signature + **kwargs: Additional arguments to pass to the tool + + Returns: + The tool + + Examples: + + .. code-block:: python + + def add(a: int, b: int) -> int: + \"\"\"Add two numbers\"\"\" + return a + b + tool = StructuredTool.from_function(add) + tool.run(1, 2) # 3 + """ + + if func is not None: + source_function = func + elif coroutine is not None: + source_function = coroutine + else: + raise ValueError("Function and/or coroutine must be provided") + name = name or source_function.__name__ + description = description or source_function.__doc__ + if description is None: + raise ValueError( + "Function must have a docstring if description not provided." + ) + + # Description example: + # search_api(query: str) - Searches the API for the query. + sig = signature(source_function) + description = f"{name}{sig} - {description.strip()}" + _args_schema = args_schema + if _args_schema is None and infer_schema: + _args_schema = create_schema_from_function(f"{name}Schema", source_function) + return cls( + name=name, + func=func, + coroutine=coroutine, + args_schema=_args_schema, + description=description, + return_direct=return_direct, + **kwargs, + ) + + +def tool( + *args: Union[str, Callable, Runnable], + return_direct: bool = False, + args_schema: Optional[Type[BaseModel]] = None, + infer_schema: bool = True, +) -> Callable: + """Make tools out of functions, can be used with or without arguments. + + Args: + *args: The arguments to the tool. + return_direct: Whether to return directly from the tool rather + than continuing the agent loop. + args_schema: optional argument schema for user to specify + infer_schema: Whether to infer the schema of the arguments from + the function's signature. This also makes the resultant tool + accept a dictionary input to its `run()` function. + + Requires: + - Function must be of type (str) -> str + - Function must have a docstring + + Examples: + .. 
code-block:: python + + @tool + def search_api(query: str) -> str: + # Searches the API for the query. + return + + @tool("search", return_direct=True) + def search_api(query: str) -> str: + # Searches the API for the query. + return + """ + + def _make_with_name(tool_name: str) -> Callable: + def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: + if isinstance(dec_func, Runnable): + runnable = dec_func + + if runnable.input_schema.schema().get("type") != "object": + raise ValueError("Runnable must have an object schema.") + + async def ainvoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) + + def invoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return runnable.invoke(kwargs, {"callbacks": callbacks}) + + coroutine = ainvoke_wrapper + func = invoke_wrapper + schema: Optional[Type[BaseModel]] = runnable.input_schema + description = repr(runnable) + elif inspect.iscoroutinefunction(dec_func): + coroutine = dec_func + func = None + schema = args_schema + description = None + else: + coroutine = None + func = dec_func + schema = args_schema + description = None + + if infer_schema or args_schema is not None: + return StructuredTool.from_function( + func, + coroutine, + name=tool_name, + description=description, + return_direct=return_direct, + args_schema=schema, + infer_schema=infer_schema, + ) + # If someone doesn't want a schema applied, we must treat it as + # a simple string->string function + if func.__doc__ is None: + raise ValueError( + "Function must have a docstring if " + "description not provided and infer_schema is False." + ) + return Tool( + name=tool_name, + func=func, + description=f"{tool_name} tool", + return_direct=return_direct, + coroutine=coroutine, + ) + + return _make_tool + + if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): + return _make_with_name(args[0])(args[1]) + elif len(args) == 1 and isinstance(args[0], str): + # if the argument is a string, then we use the string as the tool name + # Example usage: @tool("search", return_direct=True) + return _make_with_name(args[0]) + elif len(args) == 1 and callable(args[0]): + # if the argument is a function, then we use the function name as the tool name + # Example usage: @tool + return _make_with_name(args[0].__name__)(args[0]) + elif len(args) == 0: + # if there are no arguments, then we use the function name as the tool name + # Example usage: @tool(return_direct=True) + def _partial(func: Callable[[str], str]) -> BaseTool: + return _make_with_name(func.__name__)(func) + + return _partial + else: + raise ValueError("Too many arguments for tool decorator") diff --git a/libs/core/langchain_core/utils/__init__.py b/libs/core/langchain_core/utils/__init__.py new file mode 100644 index 00000000000..df7a586b8c3 --- /dev/null +++ b/libs/core/langchain_core/utils/__init__.py @@ -0,0 +1,38 @@ +""" +**Utility functions** for LangChain. + +These functions do not depend on any other LangChain module. 
+""" + +from langchain_core.utils.formatting import StrictFormatter, formatter +from langchain_core.utils.input import ( + get_bolded_text, + get_color_mapping, + get_colored_text, + print_text, +) +from langchain_core.utils.utils import ( + check_package_version, + convert_to_secret_str, + get_pydantic_field_names, + guard_import, + mock_now, + raise_for_status_with_text, + xor_args, +) + +__all__ = [ + "StrictFormatter", + "check_package_version", + "convert_to_secret_str", + "formatter", + "get_bolded_text", + "get_color_mapping", + "get_colored_text", + "get_pydantic_field_names", + "guard_import", + "mock_now", + "print_text", + "raise_for_status_with_text", + "xor_args", +] diff --git a/libs/core/langchain_core/utils/aiter.py b/libs/core/langchain_core/utils/aiter.py new file mode 100644 index 00000000000..ca44dee3958 --- /dev/null +++ b/libs/core/langchain_core/utils/aiter.py @@ -0,0 +1,209 @@ +""" +Adapted from +https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py +MIT License +""" + +from collections import deque +from typing import ( + Any, + AsyncContextManager, + AsyncGenerator, + AsyncIterator, + Awaitable, + Callable, + Deque, + Generic, + Iterator, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, + overload, +) + +T = TypeVar("T") + +_no_default = object() + + +# https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54 +# before 3.10, the builtin anext() was not available +def py_anext( + iterator: AsyncIterator[T], default: Union[T, Any] = _no_default +) -> Awaitable[Union[T, None, Any]]: + """Pure-Python implementation of anext() for testing purposes. + + Closely matches the builtin anext() C implementation. + Can be used to compare the built-in implementation of the inner + coroutines machinery to C-implementation of __anext__() and send() + or throw() on the returned generator. + """ + + try: + __anext__ = cast( + Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__ + ) + except AttributeError: + raise TypeError(f"{iterator!r} is not an async iterator") + + if default is _no_default: + return __anext__(iterator) + + async def anext_impl() -> Union[T, Any]: + try: + # The C code is way more low-level than this, as it implements + # all methods of the iterator protocol. In this implementation + # we're relying on higher-level coroutine concepts, but that's + # exactly what we want -- crosstest pure-Python high-level + # implementation and low-level C anext() iterators. + return await __anext__(iterator) + except StopAsyncIteration: + return default + + return anext_impl() + + +class NoLock: + """Dummy lock that provides the proper interface but no protection""" + + async def __aenter__(self) -> None: + pass + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + return False + + +async def tee_peer( + iterator: AsyncIterator[T], + # the buffer specific to this peer + buffer: Deque[T], + # the buffers of all peers, including our own + peers: List[Deque[T]], + lock: AsyncContextManager[Any], +) -> AsyncGenerator[T, None]: + """An individual iterator of a :py:func:`~.tee`""" + try: + while True: + if not buffer: + async with lock: + # Another peer produced an item while we were waiting for the lock. + # Proceed with the next loop iteration to yield the item. + if buffer: + continue + try: + item = await iterator.__anext__() + except StopAsyncIteration: + break + else: + # Append to all buffers, including our own. 
We'll fetch our + # item from the buffer again, instead of yielding it directly. + # This ensures the proper item ordering if any of our peers + # are fetching items concurrently. They may have buffered their + # item already. + for peer_buffer in peers: + peer_buffer.append(item) + yield buffer.popleft() + finally: + async with lock: + # this peer is done – remove its buffer + for idx, peer_buffer in enumerate(peers): # pragma: no branch + if peer_buffer is buffer: + peers.pop(idx) + break + # if we are the last peer, try and close the iterator + if not peers and hasattr(iterator, "aclose"): + await iterator.aclose() + + +class Tee(Generic[T]): + """ + Create ``n`` separate asynchronous iterators over ``iterable`` + + This splits a single ``iterable`` into multiple iterators, each providing + the same items in the same order. + All child iterators may advance separately but share the same items + from ``iterable`` -- when the most advanced iterator retrieves an item, + it is buffered until the least advanced iterator has yielded it as well. + A ``tee`` works lazily and can handle an infinite ``iterable``, provided + that all iterators advance. + + .. code-block:: python3 + + async def derivative(sensor_data): + previous, current = a.tee(sensor_data, n=2) + await a.anext(previous) # advance one iterator + return a.map(operator.sub, previous, current) + + Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead + of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked + to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method + immediately closes all children, and it can be used in an ``async with`` context + for the same effect. + + If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not* + provide these items. Also, ``tee`` must internally buffer each item until the + last iterator has yielded it; if the most and least advanced iterator differ + by most data, using a :py:class:`list` is more efficient (but not lazy). + + If the underlying iterable is concurrency safe (``anext`` may be awaited + concurrently) the resulting iterators are concurrency safe as well. Otherwise, + the iterators are safe if there is only ever one single "most advanced" iterator. + To enforce sequential use of ``anext``, provide a ``lock`` + - e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application - + and access is automatically synchronised. + """ + + def __init__( + self, + iterable: AsyncIterator[T], + n: int = 2, + *, + lock: Optional[AsyncContextManager[Any]] = None, + ): + self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist + self._buffers: List[Deque[T]] = [deque() for _ in range(n)] + self._children = tuple( + tee_peer( + iterator=self._iterator, + buffer=buffer, + peers=self._buffers, + lock=lock if lock is not None else NoLock(), + ) + for buffer in self._buffers + ) + + def __len__(self) -> int: + return len(self._children) + + @overload + def __getitem__(self, item: int) -> AsyncIterator[T]: + ... + + @overload + def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]: + ... 
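+    # Illustrative usage sketch (assuming an async generator ``source`` and a
+    # running event loop; ``atee`` is the alias defined at the bottom of this
+    # module):
+    #
+    #     first, second = atee(source, n=2)
+    #     async for item in first:
+    #         ...
+    #     await second.aclose()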
+ + def __getitem__( + self, item: Union[int, slice] + ) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]: + return self._children[item] + + def __iter__(self) -> Iterator[AsyncIterator[T]]: + yield from self._children + + async def __aenter__(self) -> "Tee[T]": + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + await self.aclose() + return False + + async def aclose(self) -> None: + for child in self._children: + await child.aclose() + + +atee = Tee diff --git a/libs/core/langchain_core/utils/formatting.py b/libs/core/langchain_core/utils/formatting.py new file mode 100644 index 00000000000..3b3b597b083 --- /dev/null +++ b/libs/core/langchain_core/utils/formatting.py @@ -0,0 +1,38 @@ +"""Utilities for formatting strings.""" +from string import Formatter +from typing import Any, List, Mapping, Sequence, Union + + +class StrictFormatter(Formatter): + """A subclass of formatter that checks for extra keys.""" + + def check_unused_args( + self, + used_args: Sequence[Union[int, str]], + args: Sequence, + kwargs: Mapping[str, Any], + ) -> None: + """Check to see if extra parameters are passed.""" + extra = set(kwargs).difference(used_args) + if extra: + raise KeyError(extra) + + def vformat( + self, format_string: str, args: Sequence, kwargs: Mapping[str, Any] + ) -> str: + """Check that no arguments are provided.""" + if len(args) > 0: + raise ValueError( + "No arguments should be provided, " + "everything should be passed as keyword arguments." + ) + return super().vformat(format_string, args, kwargs) + + def validate_input_variables( + self, format_string: str, input_variables: List[str] + ) -> None: + dummy_inputs = {input_variable: "foo" for input_variable in input_variables} + super().format(format_string, **dummy_inputs) + + +formatter = StrictFormatter() diff --git a/libs/core/langchain_core/utils/input.py b/libs/core/langchain_core/utils/input.py new file mode 100644 index 00000000000..8d5ae6cc24f --- /dev/null +++ b/libs/core/langchain_core/utils/input.py @@ -0,0 +1,42 @@ +"""Handle chained inputs.""" +from typing import Dict, List, Optional, TextIO + +_TEXT_COLOR_MAPPING = { + "blue": "36;1", + "yellow": "33;1", + "pink": "38;5;200", + "green": "32;1", + "red": "31;1", +} + + +def get_color_mapping( + items: List[str], excluded_colors: Optional[List] = None +) -> Dict[str, str]: + """Get mapping for items to a support color.""" + colors = list(_TEXT_COLOR_MAPPING.keys()) + if excluded_colors is not None: + colors = [c for c in colors if c not in excluded_colors] + color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)} + return color_mapping + + +def get_colored_text(text: str, color: str) -> str: + """Get colored text.""" + color_str = _TEXT_COLOR_MAPPING[color] + return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" + + +def get_bolded_text(text: str) -> str: + """Get bolded text.""" + return f"\033[1m{text}\033[0m" + + +def print_text( + text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None +) -> None: + """Print text with highlighting and no end characters.""" + text_to_print = get_colored_text(text, color) if color else text + print(text_to_print, end=end, file=file) + if file: + file.flush() # ensure all printed content are written to file diff --git a/libs/core/langchain_core/utils/iter.py b/libs/core/langchain_core/utils/iter.py new file mode 100644 index 00000000000..60834163c3f --- /dev/null +++ b/libs/core/langchain_core/utils/iter.py @@ -0,0 +1,175 @@ +from 
collections import deque +from itertools import islice +from typing import ( + Any, + ContextManager, + Deque, + Generator, + Generic, + Iterable, + Iterator, + List, + Optional, + Tuple, + TypeVar, + Union, + overload, +) + +from typing_extensions import Literal + +T = TypeVar("T") + + +class NoLock: + """Dummy lock that provides the proper interface but no protection""" + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: + return False + + +def tee_peer( + iterator: Iterator[T], + # the buffer specific to this peer + buffer: Deque[T], + # the buffers of all peers, including our own + peers: List[Deque[T]], + lock: ContextManager[Any], +) -> Generator[T, None, None]: + """An individual iterator of a :py:func:`~.tee`""" + try: + while True: + if not buffer: + with lock: + # Another peer produced an item while we were waiting for the lock. + # Proceed with the next loop iteration to yield the item. + if buffer: + continue + try: + item = next(iterator) + except StopIteration: + break + else: + # Append to all buffers, including our own. We'll fetch our + # item from the buffer again, instead of yielding it directly. + # This ensures the proper item ordering if any of our peers + # are fetching items concurrently. They may have buffered their + # item already. + for peer_buffer in peers: + peer_buffer.append(item) + yield buffer.popleft() + finally: + with lock: + # this peer is done – remove its buffer + for idx, peer_buffer in enumerate(peers): # pragma: no branch + if peer_buffer is buffer: + peers.pop(idx) + break + # if we are the last peer, try and close the iterator + if not peers and hasattr(iterator, "close"): + iterator.close() + + +class Tee(Generic[T]): + """ + Create ``n`` separate asynchronous iterators over ``iterable`` + + This splits a single ``iterable`` into multiple iterators, each providing + the same items in the same order. + All child iterators may advance separately but share the same items + from ``iterable`` -- when the most advanced iterator retrieves an item, + it is buffered until the least advanced iterator has yielded it as well. + A ``tee`` works lazily and can handle an infinite ``iterable``, provided + that all iterators advance. + + .. code-block:: python3 + + async def derivative(sensor_data): + previous, current = a.tee(sensor_data, n=2) + await a.anext(previous) # advance one iterator + return a.map(operator.sub, previous, current) + + Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead + of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked + to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method + immediately closes all children, and it can be used in an ``async with`` context + for the same effect. + + If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not* + provide these items. Also, ``tee`` must internally buffer each item until the + last iterator has yielded it; if the most and least advanced iterator differ + by most data, using a :py:class:`list` is more efficient (but not lazy). + + If the underlying iterable is concurrency safe (``anext`` may be awaited + concurrently) the resulting iterators are concurrency safe as well. Otherwise, + the iterators are safe if there is only ever one single "most advanced" iterator. + To enforce sequential use of ``anext``, provide a ``lock`` + - e.g. 
an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
+    and access is automatically synchronised.
+    """
+
+    def __init__(
+        self,
+        iterable: Iterator[T],
+        n: int = 2,
+        *,
+        lock: Optional[ContextManager[Any]] = None,
+    ):
+        self._iterator = iter(iterable)
+        self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
+        self._children = tuple(
+            tee_peer(
+                iterator=self._iterator,
+                buffer=buffer,
+                peers=self._buffers,
+                lock=lock if lock is not None else NoLock(),
+            )
+            for buffer in self._buffers
+        )
+
+    def __len__(self) -> int:
+        return len(self._children)
+
+    @overload
+    def __getitem__(self, item: int) -> Iterator[T]:
+        ...
+
+    @overload
+    def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]:
+        ...
+
+    def __getitem__(
+        self, item: Union[int, slice]
+    ) -> Union[Iterator[T], Tuple[Iterator[T], ...]]:
+        return self._children[item]
+
+    def __iter__(self) -> Iterator[Iterator[T]]:
+        yield from self._children
+
+    def __enter__(self) -> "Tee[T]":
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
+        self.close()
+        return False
+
+    def close(self) -> None:
+        for child in self._children:
+            child.close()
+
+
+# Why this is needed https://stackoverflow.com/a/44638570
+safetee = Tee
+
+
+def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
+    """Utility batching function."""
+    it = iter(iterable)
+    while True:
+        chunk = list(islice(it, size))
+        if not chunk:
+            return
+        yield chunk
diff --git a/libs/core/langchain_core/utils/loading.py b/libs/core/langchain_core/utils/loading.py
new file mode 100644
index 00000000000..9e3f83ec70f
--- /dev/null
+++ b/libs/core/langchain_core/utils/loading.py
@@ -0,0 +1,54 @@
+"""Utilities for loading configurations from langchain-hub."""
+
+import os
+import re
+import tempfile
+from pathlib import Path, PurePosixPath
+from typing import Any, Callable, Optional, Set, TypeVar, Union
+from urllib.parse import urljoin
+
+import requests
+
+DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
+URL_BASE = os.environ.get(
+    "LANGCHAIN_HUB_URL_BASE",
+    "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
+)
+HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")
+
+T = TypeVar("T")
+
+
+def try_load_from_hub(
+    path: Union[str, Path],
+    loader: Callable[[str], T],
+    valid_prefix: str,
+    valid_suffixes: Set[str],
+    **kwargs: Any,
+) -> Optional[T]:
+    """Load configuration from hub. Returns None if path is not a hub path."""
+    if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
+        return None
+    ref, remote_path_str = match.groups()
+    ref = ref[1:] if ref else DEFAULT_REF
+    remote_path = Path(remote_path_str)
+    if remote_path.parts[0] != valid_prefix:
+        return None
+    if remote_path.suffix[1:] not in valid_suffixes:
+        raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")
+
+    # Using Path with URLs is not recommended, because on Windows
+    # the backslash is used as the path separator, which can cause issues
+    # when working with URLs that use forward slashes as the path separator.
+    # Instead, use PurePosixPath to ensure that forward slashes are used as the
+    # path separator, regardless of the operating system.
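+    # Illustrative example (hypothetical path, assuming the default URL_BASE):
+    # "lc@my-branch://prompts/foo/prompt.json" is fetched from
+    # https://raw.githubusercontent.com/hwchase17/langchain-hub/my-branch/prompts/foo/prompt.json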
+ full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__()) + + r = requests.get(full_url, timeout=5) + if r.status_code != 200: + raise ValueError(f"Could not find file at {full_url}") + with tempfile.TemporaryDirectory() as tmpdirname: + file = Path(tmpdirname) / remote_path.name + with open(file, "wb") as f: + f.write(r.content) + return loader(str(file), **kwargs) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py new file mode 100644 index 00000000000..80ddb81fcb9 --- /dev/null +++ b/libs/core/langchain_core/utils/pydantic.py @@ -0,0 +1,14 @@ +"""Utilities for tests.""" + + +def get_pydantic_major_version() -> int: + """Get the major version of Pydantic.""" + try: + import pydantic + + return int(pydantic.__version__.split(".")[0]) + except ImportError: + return 0 + + +PYDANTIC_MAJOR_VERSION = get_pydantic_major_version() diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py new file mode 100644 index 00000000000..9b63ddf3ea6 --- /dev/null +++ b/libs/core/langchain_core/utils/utils.py @@ -0,0 +1,180 @@ +"""Generic utility functions.""" +import contextlib +import datetime +import functools +import importlib +import warnings +from importlib.metadata import version +from typing import Any, Callable, Dict, Optional, Set, Tuple, Union + +from packaging.version import parse +from requests import HTTPError, Response + +from langchain_core.pydantic_v1 import SecretStr + + +def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: + """Validate specified keyword args are mutually exclusive.""" + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + """Validate exactly one arg in each group is not None.""" + counts = [ + sum(1 for arg in arg_group if kwargs.get(arg) is not None) + for arg_group in arg_groups + ] + invalid_groups = [i for i, count in enumerate(counts) if count != 1] + if invalid_groups: + invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups] + raise ValueError( + "Exactly one argument in each of the following" + " groups must be defined:" + f" {', '.join(invalid_group_names)}" + ) + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def raise_for_status_with_text(response: Response) -> None: + """Raise an error with the response text.""" + try: + response.raise_for_status() + except HTTPError as e: + raise ValueError(response.text) from e + + +@contextlib.contextmanager +def mock_now(dt_value): # type: ignore + """Context manager for mocking out datetime.now() in unit tests. + + Example: + with mock_now(datetime.datetime(2011, 2, 3, 10, 11)): + assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11) + """ + + class MockDateTime(datetime.datetime): + """Mock datetime.datetime.now() with a fixed datetime.""" + + @classmethod + def now(cls): # type: ignore + # Create a copy of dt_value. 
+ return datetime.datetime( + dt_value.year, + dt_value.month, + dt_value.day, + dt_value.hour, + dt_value.minute, + dt_value.second, + dt_value.microsecond, + dt_value.tzinfo, + ) + + real_datetime = datetime.datetime + datetime.datetime = MockDateTime + try: + yield datetime.datetime + finally: + datetime.datetime = real_datetime + + +def guard_import( + module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None +) -> Any: + """Dynamically imports a module and raises a helpful exception if the module is not + installed.""" + try: + module = importlib.import_module(module_name, package) + except ImportError: + raise ImportError( + f"Could not import {module_name} python package. " + f"Please install it with `pip install {pip_name or module_name}`." + ) + return module + + +def check_package_version( + package: str, + lt_version: Optional[str] = None, + lte_version: Optional[str] = None, + gt_version: Optional[str] = None, + gte_version: Optional[str] = None, +) -> None: + """Check the version of a package.""" + imported_version = parse(version(package)) + if lt_version is not None and imported_version >= parse(lt_version): + raise ValueError( + f"Expected {package} version to be < {lt_version}. Received " + f"{imported_version}." + ) + if lte_version is not None and imported_version > parse(lte_version): + raise ValueError( + f"Expected {package} version to be <= {lte_version}. Received " + f"{imported_version}." + ) + if gt_version is not None and imported_version <= parse(gt_version): + raise ValueError( + f"Expected {package} version to be > {gt_version}. Received " + f"{imported_version}." + ) + if gte_version is not None and imported_version < parse(gte_version): + raise ValueError( + f"Expected {package} version to be >= {gte_version}. Received " + f"{imported_version}." + ) + + +def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]: + """Get field names, including aliases, for a pydantic class. + + Args: + pydantic_cls: Pydantic class.""" + all_required_field_names = set() + for field in pydantic_cls.__fields__.values(): + all_required_field_names.add(field.name) + if field.has_alias: + all_required_field_names.add(field.alias) + return all_required_field_names + + +def build_extra_kwargs( + extra_kwargs: Dict[str, Any], + values: Dict[str, Any], + all_required_field_names: Set[str], +) -> Dict[str, Any]: + """Build extra kwargs from values and extra_kwargs. + + Args: + extra_kwargs: Extra kwargs passed in by user. + values: Values passed in by user. + all_required_field_names: All required field names for the pydantic class. + """ + for field_name in list(values): + if field_name in extra_kwargs: + raise ValueError(f"Found {field_name} supplied twice.") + if field_name not in all_required_field_names: + warnings.warn( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. + Please confirm that {field_name} is what you intended.""" + ) + extra_kwargs[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys()) + if invalid_model_kwargs: + raise ValueError( + f"Parameters {invalid_model_kwargs} should be specified explicitly. " + f"Instead they were passed in as part of `model_kwargs` parameter." 
+ ) + + return extra_kwargs + + +def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr: + """Convert a string to a SecretStr if needed.""" + if isinstance(value, SecretStr): + return value + return SecretStr(value) diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock new file mode 100644 index 00000000000..99e7af1b38b --- /dev/null +++ b/libs/core/poetry.lock @@ -0,0 +1,2689 @@ +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. + +[[package]] +name = "annotated-types" +version = "0.6.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.6.0-py3-none-any.whl", hash = "sha256:0641064de18ba7a25dee8f96403ebc39113d0cb953a01429249d5c7564666a43"}, + {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} + +[[package]] +name = "anyio" +version = "4.0.0" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +optional = false +python-versions = ">=3.8" +files = [ + {file = "anyio-4.0.0-py3-none-any.whl", hash = "sha256:cfdb2b588b9fc25ede96d8db56ed50848b0b649dca3dd1df0b11f683bb9e0b5f"}, + {file = "anyio-4.0.0.tar.gz", hash = "sha256:f7ed51751b2c2add651e5747c891b47e26d2a21be5d32d9311dfe9692f3e5d7a"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" + +[package.extras] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.22)"] + +[[package]] +name = "appnope" +version = "0.1.3" +description = "Disable App Nap on macOS >= 10.9" +optional = false +python-versions = "*" +files = [ + {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"}, + {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"}, +] + +[[package]] +name = "argon2-cffi" +version = "23.1.0" +description = "Argon2 for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "argon2_cffi-23.1.0-py3-none-any.whl", hash = "sha256:c670642b78ba29641818ab2e68bd4e6a78ba53b7eff7b4c3815ae16abf91c7ea"}, + {file = "argon2_cffi-23.1.0.tar.gz", hash = "sha256:879c3e79a2729ce768ebb7d36d4609e3a78a4ca2ec3a9f12286ca057e3d0db08"}, +] + +[package.dependencies] +argon2-cffi-bindings = "*" + +[package.extras] +dev = ["argon2-cffi[tests,typing]", "tox (>4)"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-copybutton", "sphinx-notfound-page"] +tests = ["hypothesis", "pytest"] +typing = ["mypy"] + +[[package]] +name = "argon2-cffi-bindings" +version = "21.2.0" +description = "Low-level CFFI bindings for Argon2" +optional = false +python-versions = ">=3.6" +files = [ + {file = "argon2-cffi-bindings-21.2.0.tar.gz", hash = "sha256:bb89ceffa6c791807d1305ceb77dbfacc5aa499891d2c55661c6459651fc39e3"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ccb949252cb2ab3a08c02024acb77cfb179492d5701c7cbdbfd776124d4d2367"}, + {file = 
"argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9524464572e12979364b7d600abf96181d3541da11e23ddf565a32e70bd4dc0d"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b746dba803a79238e925d9046a63aa26bf86ab2a2fe74ce6b009a1c3f5c8f2ae"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58ed19212051f49a523abb1dbe954337dc82d947fb6e5a0da60f7c8471a8476c"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:bd46088725ef7f58b5a1ef7ca06647ebaf0eb4baff7d1d0d177c6cc8744abd86"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_i686.whl", hash = "sha256:8cd69c07dd875537a824deec19f978e0f2078fdda07fd5c42ac29668dda5f40f"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:f1152ac548bd5b8bcecfb0b0371f082037e47128653df2e8ba6e914d384f3c3e"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win32.whl", hash = "sha256:603ca0aba86b1349b147cab91ae970c63118a0f30444d4bc80355937c950c082"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win_amd64.whl", hash = "sha256:b2ef1c30440dbbcba7a5dc3e319408b59676e2e039e2ae11a8775ecf482b192f"}, + {file = "argon2_cffi_bindings-21.2.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e415e3f62c8d124ee16018e491a009937f8cf7ebf5eb430ffc5de21b900dad93"}, + {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3e385d1c39c520c08b53d63300c3ecc28622f076f4c2b0e6d7e796e9f6502194"}, + {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c3e3cc67fdb7d82c4718f19b4e7a87123caf8a93fde7e23cf66ac0337d3cb3f"}, + {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a22ad9800121b71099d0fb0a65323810a15f2e292f2ba450810a7316e128ee5"}, + {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f9f8b450ed0547e3d473fdc8612083fd08dd2120d6ac8f73828df9b7d45bb351"}, + {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:93f9bf70084f97245ba10ee36575f0c3f1e7d7724d67d8e5b08e61787c320ed7"}, + {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3b9ef65804859d335dc6b31582cad2c5166f0c3e7975f324d9ffaa34ee7e6583"}, + {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4966ef5848d820776f5f562a7d45fdd70c2f330c961d0d745b784034bd9f48d"}, + {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ef543a89dee4db46a1a6e206cd015360e5a75822f76df533845c3cbaf72670"}, + {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed2937d286e2ad0cc79a7087d3c272832865f779430e0cc2b4f3718d3159b0cb"}, + {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5e00316dabdaea0b2dd82d141cc66889ced0cdcbfa599e8b471cf22c620c329a"}, +] + +[package.dependencies] +cffi = ">=1.0.1" + +[package.extras] +dev = ["cogapp", "pre-commit", "pytest", "wheel"] +tests = ["pytest"] + +[[package]] +name = "arrow" +version = "1.3.0" +description = "Better dates & times 
for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "arrow-1.3.0-py3-none-any.whl", hash = "sha256:c728b120ebc00eb84e01882a6f5e7927a53960aa990ce7dd2b10f39005a67f80"}, + {file = "arrow-1.3.0.tar.gz", hash = "sha256:d4540617648cb5f895730f1ad8c82a65f2dad0166f57b75f3ca54759c4d67a85"}, +] + +[package.dependencies] +python-dateutil = ">=2.7.0" +types-python-dateutil = ">=2.8.10" + +[package.extras] +doc = ["doc8", "sphinx (>=7.0.0)", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx_rtd_theme (>=1.3.0)"] +test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", "pytest-mock", "pytz (==2021.1)", "simplejson (==3.*)"] + +[[package]] +name = "asttokens" +version = "2.4.1" +description = "Annotate AST trees with source code positions" +optional = false +python-versions = "*" +files = [ + {file = "asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24"}, + {file = "asttokens-2.4.1.tar.gz", hash = "sha256:b03869718ba9a6eb027e134bfdf69f38a236d681c83c160d510768af11254ba0"}, +] + +[package.dependencies] +six = ">=1.12.0" + +[package.extras] +astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"] +test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] + +[[package]] +name = "async-lru" +version = "2.0.4" +description = "Simple LRU cache for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "async-lru-2.0.4.tar.gz", hash = "sha256:b8a59a5df60805ff63220b2a0c5b5393da5521b113cd5465a44eb037d81a5627"}, + {file = "async_lru-2.0.4-py3-none-any.whl", hash = "sha256:ff02944ce3c288c5be660c42dbcca0742b32c3b279d6dceda655190240b99224"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + +[[package]] +name = "attrs" +version = "23.1.0" +description = "Classes Without Boilerplate" +optional = false +python-versions = ">=3.7" +files = [ + {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, + {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, +] + +[package.extras] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[docs,tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] + +[[package]] +name = "babel" +version = "2.13.1" +description = "Internationalization utilities" +optional = false +python-versions = ">=3.7" +files = [ + {file = "Babel-2.13.1-py3-none-any.whl", hash = "sha256:7077a4984b02b6727ac10f1f7294484f737443d7e2e66c5e4380e41a3ae0b4ed"}, + {file = "Babel-2.13.1.tar.gz", hash = "sha256:33e0952d7dd6374af8dbf6768cc4ddf3ccfefc244f9986d4074704f2fbd18900"}, +] + +[package.dependencies] +pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""} +setuptools = {version = "*", markers = "python_version >= \"3.12\""} + +[package.extras] +dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"] + +[[package]] +name = "backcall" +version = "0.2.0" +description = "Specifications for callback functions passed in to an API" +optional = false +python-versions = "*" +files = [ + {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = 
"sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"}, + {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"}, +] + +[[package]] +name = "beautifulsoup4" +version = "4.12.2" +description = "Screen-scraping library" +optional = false +python-versions = ">=3.6.0" +files = [ + {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"}, + {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"}, +] + +[package.dependencies] +soupsieve = ">1.2" + +[package.extras] +html5lib = ["html5lib"] +lxml = ["lxml"] + +[[package]] +name = "bleach" +version = "6.1.0" +description = "An easy safelist-based HTML-sanitizing tool." +optional = false +python-versions = ">=3.8" +files = [ + {file = "bleach-6.1.0-py3-none-any.whl", hash = "sha256:3225f354cfc436b9789c66c4ee030194bee0568fbf9cbdad3bc8b5c26c5f12b6"}, + {file = "bleach-6.1.0.tar.gz", hash = "sha256:0a31f1837963c41d46bbf1331b8778e1308ea0791db03cc4e7357b97cf42a8fe"}, +] + +[package.dependencies] +six = ">=1.9.0" +webencodings = "*" + +[package.extras] +css = ["tinycss2 (>=1.1.0,<1.3)"] + +[[package]] +name = "certifi" +version = "2023.11.17" +description = "Python package for providing Mozilla's CA Bundle." +optional = false +python-versions = ">=3.6" +files = [ + {file = "certifi-2023.11.17-py3-none-any.whl", hash = "sha256:e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474"}, + {file = "certifi-2023.11.17.tar.gz", hash = "sha256:9b469f3a900bf28dc19b8cfbf8019bf47f7fdd1a65a1d4ffb98fc14166beb4d1"}, +] + +[[package]] +name = "cffi" +version = "1.16.0" +description = "Foreign Function Interface for Python calling C code." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"}, + {file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e61e3e4fa664a8588aa25c883eab612a188c725755afff6289454d6362b9673"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a72e8961a86d19bdb45851d8f1f08b041ea37d2bd8d4fd19903bc3083d80c896"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b50bf3f55561dac5438f8e70bfcdfd74543fd60df5fa5f62d94e5867deca684"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7651c50c8c5ef7bdb41108b7b8c5a83013bfaa8a935590c5d74627c047a583c7"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4108df7fe9b707191e55f33efbcb2d81928e10cea45527879a4749cbe472614"}, + {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:32c68ef735dbe5857c810328cb2481e24722a59a2003018885514d4c09af9743"}, + {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:673739cb539f8cdaa07d92d02efa93c9ccf87e345b9a0b556e3ecc666718468d"}, + {file = "cffi-1.16.0-cp310-cp310-win32.whl", hash = "sha256:9f90389693731ff1f659e55c7d1640e2ec43ff725cc61b04b2f9c6d8d017df6a"}, + {file = "cffi-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:e6024675e67af929088fda399b2094574609396b1decb609c55fa58b028a32a1"}, + {file = "cffi-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b84834d0cf97e7d27dd5b7f3aca7b6e9263c56308ab9dc8aae9784abb774d404"}, + {file = "cffi-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b8ebc27c014c59692bb2664c7d13ce7a6e9a629be20e54e7271fa696ff2b417"}, + {file = "cffi-1.16.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ee07e47c12890ef248766a6e55bd38ebfb2bb8edd4142d56db91b21ea68b7627"}, + {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8a9d3ebe49f084ad71f9269834ceccbf398253c9fac910c4fd7053ff1386936"}, + {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e70f54f1796669ef691ca07d046cd81a29cb4deb1e5f942003f401c0c4a2695d"}, + {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5bf44d66cdf9e893637896c7faa22298baebcd18d1ddb6d2626a6e39793a1d56"}, + {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b78010e7b97fef4bee1e896df8a4bbb6712b7f05b7ef630f9d1da00f6444d2e"}, + {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c6a164aa47843fb1b01e941d385aab7215563bb8816d80ff3a363a9f8448a8dc"}, + {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e09f3ff613345df5e8c3667da1d918f9149bd623cd9070c983c013792a9a62eb"}, + {file = "cffi-1.16.0-cp311-cp311-win32.whl", hash = "sha256:2c56b361916f390cd758a57f2e16233eb4f64bcbeee88a4881ea90fca14dc6ab"}, + {file = "cffi-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:db8e577c19c0fda0beb7e0d4e09e0ba74b1e4c092e0e40bfa12fe05b6f6d75ba"}, + {file = "cffi-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:fa3a0128b152627161ce47201262d3140edb5a5c3da88d73a1b790a959126956"}, + {file = "cffi-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68e7c44931cc171c54ccb702482e9fc723192e88d25a0e133edd7aff8fcd1f6e"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abd808f9c129ba2beda4cfc53bde801e5bcf9d6e0f22f095e45327c038bfe68e"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88e2b3c14bdb32e440be531ade29d3c50a1a59cd4e51b1dd8b0865c54ea5d2e2"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcc8eb6d5902bb1cf6dc4f187ee3ea80a1eba0a89aba40a5cb20a5087d961357"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7be2d771cdba2942e13215c4e340bfd76398e9227ad10402a8767ab1865d2e6"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e715596e683d2ce000574bae5d07bd522c781a822866c20495e52520564f0969"}, + {file = "cffi-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2d92b25dbf6cae33f65005baf472d2c245c050b1ce709cc4588cdcdd5495b520"}, + {file = "cffi-1.16.0-cp312-cp312-win32.whl", hash = "sha256:b2ca4e77f9f47c55c194982e10f058db063937845bb2b7a86c84a6cfe0aefa8b"}, + {file = "cffi-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:68678abf380b42ce21a5f2abde8efee05c114c2fdb2e9eef2efdb0257fba1235"}, + {file = "cffi-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0c9ef6ff37e974b73c25eecc13952c55bceed9112be2d9d938ded8e856138bcc"}, + {file = "cffi-1.16.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a09582f178759ee8128d9270cd1344154fd473bb77d94ce0aeb2a93ebf0feaf0"}, + {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e760191dd42581e023a68b758769e2da259b5d52e3103c6060ddc02c9edb8d7b"}, + {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80876338e19c951fdfed6198e70bc88f1c9758b94578d5a7c4c91a87af3cf31c"}, + {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6a14b17d7e17fa0d207ac08642c8820f84f25ce17a442fd15e27ea18d67c59b"}, + {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6602bc8dc6f3a9e02b6c22c4fc1e47aa50f8f8e6d3f78a5e16ac33ef5fefa324"}, + {file = "cffi-1.16.0-cp38-cp38-win32.whl", hash = "sha256:131fd094d1065b19540c3d72594260f118b231090295d8c34e19a7bbcf2e860a"}, + {file = "cffi-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:31d13b0f99e0836b7ff893d37af07366ebc90b678b6664c955b54561fc36ef36"}, + {file = "cffi-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:582215a0e9adbe0e379761260553ba11c58943e4bbe9c36430c4ca6ac74b15ed"}, + {file = "cffi-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b29ebffcf550f9da55bec9e02ad430c992a87e5f512cd63388abb76f1036d8d2"}, + {file = "cffi-1.16.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc9b18bf40cc75f66f40a7379f6a9513244fe33c0e8aa72e2d56b0196a7ef872"}, + {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cb4a35b3642fc5c005a6755a5d17c6c8b6bcb6981baf81cea8bfbc8903e8ba8"}, + {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:b86851a328eedc692acf81fb05444bdf1891747c25af7529e39ddafaf68a4f3f"}, + {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0f31130ebc2d37cdd8e44605fb5fa7ad59049298b3f745c74fa74c62fbfcfc4"}, + {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f8e709127c6c77446a8c0a8c8bf3c8ee706a06cd44b1e827c3e6a2ee6b8c098"}, + {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:748dcd1e3d3d7cd5443ef03ce8685043294ad6bd7c02a38d1bd367cfd968e000"}, + {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8895613bcc094d4a1b2dbe179d88d7fb4a15cee43c052e8885783fac397d91fe"}, + {file = "cffi-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed86a35631f7bfbb28e108dd96773b9d5a6ce4811cf6ea468bb6a359b256b1e4"}, + {file = "cffi-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:3686dffb02459559c74dd3d81748269ffb0eb027c39a6fc99502de37d501faa8"}, + {file = "cffi-1.16.0.tar.gz", hash = "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0"}, +] + +[package.dependencies] +pycparser = "*" + +[[package]] +name = "charset-normalizer" +version = "3.3.2" +description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"}, + {file = 
"charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = 
"sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"}, + {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, +] + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "comm" +version = "0.2.0" +description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "comm-0.2.0-py3-none-any.whl", hash = "sha256:2da8d9ebb8dd7bfc247adaff99f24dce705638a8042b85cb995066793e391001"}, + {file = "comm-0.2.0.tar.gz", hash = "sha256:a517ea2ca28931c7007a7a99c562a0fa5883cfb48963140cf642c41c948498be"}, +] + +[package.dependencies] +traitlets = ">=4" + +[package.extras] +test = ["pytest"] + +[[package]] +name = "debugpy" +version = "1.8.0" +description = "An implementation of the Debug Adapter Protocol for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "debugpy-1.8.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7fb95ca78f7ac43393cd0e0f2b6deda438ec7c5e47fa5d38553340897d2fbdfb"}, + {file = "debugpy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef9ab7df0b9a42ed9c878afd3eaaff471fce3fa73df96022e1f5c9f8f8c87ada"}, + {file = "debugpy-1.8.0-cp310-cp310-win32.whl", hash = "sha256:a8b7a2fd27cd9f3553ac112f356ad4ca93338feadd8910277aff71ab24d8775f"}, + {file = "debugpy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:5d9de202f5d42e62f932507ee8b21e30d49aae7e46d5b1dd5c908db1d7068637"}, + {file = "debugpy-1.8.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:ef54404365fae8d45cf450d0544ee40cefbcb9cb85ea7afe89a963c27028261e"}, + {file = "debugpy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60009b132c91951354f54363f8ebdf7457aeb150e84abba5ae251b8e9f29a8a6"}, + {file = "debugpy-1.8.0-cp311-cp311-win32.whl", hash = "sha256:8cd0197141eb9e8a4566794550cfdcdb8b3db0818bdf8c49a8e8f8053e56e38b"}, + {file = "debugpy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:a64093656c4c64dc6a438e11d59369875d200bd5abb8f9b26c1f5f723622e153"}, + {file = "debugpy-1.8.0-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:b05a6b503ed520ad58c8dc682749113d2fd9f41ffd45daec16e558ca884008cd"}, + {file = "debugpy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c6fb41c98ec51dd010d7ed650accfd07a87fe5e93eca9d5f584d0578f28f35f"}, + {file = "debugpy-1.8.0-cp38-cp38-win32.whl", hash = "sha256:46ab6780159eeabb43c1495d9c84cf85d62975e48b6ec21ee10c95767c0590aa"}, + {file = "debugpy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:bdc5ef99d14b9c0fcb35351b4fbfc06ac0ee576aeab6b2511702e5a648a2e595"}, + {file = "debugpy-1.8.0-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:61eab4a4c8b6125d41a34bad4e5fe3d2cc145caecd63c3fe953be4cc53e65bf8"}, + {file = "debugpy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:125b9a637e013f9faac0a3d6a82bd17c8b5d2c875fb6b7e2772c5aba6d082332"}, + {file = "debugpy-1.8.0-cp39-cp39-win32.whl", hash = "sha256:57161629133113c97b387382045649a2b985a348f0c9366e22217c87b68b73c6"}, + {file = "debugpy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:e3412f9faa9ade82aa64a50b602544efcba848c91384e9f93497a458767e6926"}, + {file = "debugpy-1.8.0-py2.py3-none-any.whl", hash = "sha256:9c9b0ac1ce2a42888199df1a1906e45e6f3c9555497643a85e0bf2406e3ffbc4"}, + {file = "debugpy-1.8.0.zip", hash = "sha256:12af2c55b419521e33d5fb21bd022df0b5eb267c3e178f1d374a63a2a6bdccd0"}, +] + +[[package]] +name = "decorator" +version = "5.1.1" +description = "Decorators for Humans" +optional = false +python-versions = ">=3.5" +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + +[[package]] +name = 
"defusedxml" +version = "0.7.1" +description = "XML bomb protection for Python stdlib modules" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, + {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, +] + +[[package]] +name = "exceptiongroup" +version = "1.1.3" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"}, + {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "executing" +version = "2.0.1" +description = "Get the currently executing AST node of a frame, and other information" +optional = false +python-versions = ">=3.5" +files = [ + {file = "executing-2.0.1-py2.py3-none-any.whl", hash = "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc"}, + {file = "executing-2.0.1.tar.gz", hash = "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147"}, +] + +[package.extras] +tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] + +[[package]] +name = "fastjsonschema" +version = "2.19.0" +description = "Fastest Python implementation of JSON schema" +optional = false +python-versions = "*" +files = [ + {file = "fastjsonschema-2.19.0-py3-none-any.whl", hash = "sha256:b9fd1a2dd6971dbc7fee280a95bd199ae0dd9ce22beb91cc75e9c1c528a5170e"}, + {file = "fastjsonschema-2.19.0.tar.gz", hash = "sha256:e25df6647e1bc4a26070b700897b07b542ec898dd4f1f6ea013e7f6a88417225"}, +] + +[package.extras] +devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] + +[[package]] +name = "fqdn" +version = "1.5.1" +description = "Validates fully-qualified domain names against RFC 1123, so that they are acceptable to modern bowsers" +optional = false +python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" +files = [ + {file = "fqdn-1.5.1-py3-none-any.whl", hash = "sha256:3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014"}, + {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, +] + +[[package]] +name = "freezegun" +version = "1.2.2" +description = "Let your Python tests travel through time" +optional = false +python-versions = ">=3.6" +files = [ + {file = "freezegun-1.2.2-py3-none-any.whl", hash = "sha256:ea1b963b993cb9ea195adbd893a48d573fda951b0da64f60883d7e988b606c9f"}, + {file = "freezegun-1.2.2.tar.gz", hash = "sha256:cd22d1ba06941384410cd967d8a99d5ae2442f57dfafeff2fda5de8dc5c05446"}, +] + +[package.dependencies] +python-dateutil = ">=2.7" + +[[package]] +name = "idna" +version = "3.4" +description = "Internationalized Domain Names in Applications (IDNA)" +optional = false +python-versions = ">=3.5" +files = [ + {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, + {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, +] + +[[package]] +name = "importlib-metadata" +version = "6.8.0" 
+description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"}, + {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] + +[[package]] +name = "importlib-resources" +version = "6.1.1" +description = "Read resources from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_resources-6.1.1-py3-none-any.whl", hash = "sha256:e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6"}, + {file = "importlib_resources-6.1.1.tar.gz", hash = "sha256:3893a00122eafde6894c59914446a512f728a0c1a45f9bb9b63721b6bacf0b4a"}, +] + +[package.dependencies] +zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff", "zipp (>=3.17)"] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "ipykernel" +version = "6.26.0" +description = "IPython Kernel for Jupyter" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ipykernel-6.26.0-py3-none-any.whl", hash = "sha256:3ba3dc97424b87b31bb46586b5167b3161b32d7820b9201a9e698c71e271602c"}, + {file = "ipykernel-6.26.0.tar.gz", hash = "sha256:553856658eb8430bbe9653ea041a41bff63e9606fc4628873fc92a6cf3abd404"}, +] + +[package.dependencies] +appnope = {version = "*", markers = "platform_system == \"Darwin\""} +comm = ">=0.1.1" +debugpy = ">=1.6.5" +ipython = ">=7.23.1" +jupyter-client = ">=6.1.12" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +matplotlib-inline = ">=0.1" +nest-asyncio = "*" +packaging = "*" +psutil = "*" +pyzmq = ">=20" +tornado = ">=6.1" +traitlets = ">=5.4.0" + +[package.extras] +cov = ["coverage[toml]", "curio", "matplotlib", "pytest-cov", "trio"] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "trio"] +pyqt5 = ["pyqt5"] +pyside6 = ["pyside6"] +test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov", "pytest-timeout"] + +[[package]] +name = "ipython" +version = "8.12.3" +description = "IPython: Productive Interactive Computing" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ipython-8.12.3-py3-none-any.whl", hash = 
"sha256:b0340d46a933d27c657b211a329d0be23793c36595acf9e6ef4164bc01a1804c"}, + {file = "ipython-8.12.3.tar.gz", hash = "sha256:3910c4b54543c2ad73d06579aa771041b7d5707b033bd488669b4cf544e3b363"}, +] + +[package.dependencies] +appnope = {version = "*", markers = "sys_platform == \"darwin\""} +backcall = "*" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +decorator = "*" +jedi = ">=0.16" +matplotlib-inline = "*" +pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""} +pickleshare = "*" +prompt-toolkit = ">=3.0.30,<3.0.37 || >3.0.37,<3.1.0" +pygments = ">=2.4.0" +stack-data = "*" +traitlets = ">=5" +typing-extensions = {version = "*", markers = "python_version < \"3.10\""} + +[package.extras] +all = ["black", "curio", "docrepr", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.21)", "pandas", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"] +black = ["black"] +doc = ["docrepr", "ipykernel", "matplotlib", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "typing-extensions"] +kernel = ["ipykernel"] +nbconvert = ["nbconvert"] +nbformat = ["nbformat"] +notebook = ["ipywidgets", "notebook"] +parallel = ["ipyparallel"] +qtconsole = ["qtconsole"] +test = ["pytest (<7.1)", "pytest-asyncio", "testpath"] +test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pandas", "pytest (<7.1)", "pytest-asyncio", "testpath", "trio"] + +[[package]] +name = "ipywidgets" +version = "8.1.1" +description = "Jupyter interactive widgets" +optional = false +python-versions = ">=3.7" +files = [ + {file = "ipywidgets-8.1.1-py3-none-any.whl", hash = "sha256:2b88d728656aea3bbfd05d32c747cfd0078f9d7e159cf982433b58ad717eed7f"}, + {file = "ipywidgets-8.1.1.tar.gz", hash = "sha256:40211efb556adec6fa450ccc2a77d59ca44a060f4f9f136833df59c9f538e6e8"}, +] + +[package.dependencies] +comm = ">=0.1.3" +ipython = ">=6.1.0" +jupyterlab-widgets = ">=3.0.9,<3.1.0" +traitlets = ">=4.3.1" +widgetsnbextension = ">=4.0.9,<4.1.0" + +[package.extras] +test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] + +[[package]] +name = "isoduration" +version = "20.11.0" +description = "Operations with ISO 8601 durations" +optional = false +python-versions = ">=3.7" +files = [ + {file = "isoduration-20.11.0-py3-none-any.whl", hash = "sha256:b2904c2a4228c3d44f409c8ae8e2370eb21a26f7ac2ec5446df141dde3452042"}, + {file = "isoduration-20.11.0.tar.gz", hash = "sha256:ac2f9015137935279eac671f94f89eb00584f940f5dc49462a0c4ee692ba1bd9"}, +] + +[package.dependencies] +arrow = ">=0.15.0" + +[[package]] +name = "jedi" +version = "0.19.1" +description = "An autocompletion tool for Python that can be used for text editors." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"}, + {file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"}, +] + +[package.dependencies] +parso = ">=0.8.3,<0.9.0" + +[package.extras] +docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] +qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] +testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] + +[[package]] +name = "jinja2" +version = "3.1.2" +description = "A very fast and expressive template engine." +optional = false +python-versions = ">=3.7" +files = [ + {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, + {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, +] + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + +[[package]] +name = "json5" +version = "0.9.14" +description = "A Python implementation of the JSON5 data format." +optional = false +python-versions = "*" +files = [ + {file = "json5-0.9.14-py2.py3-none-any.whl", hash = "sha256:740c7f1b9e584a468dbb2939d8d458db3427f2c93ae2139d05f47e453eae964f"}, + {file = "json5-0.9.14.tar.gz", hash = "sha256:9ed66c3a6ca3510a976a9ef9b8c0787de24802724ab1860bc0153c7fdd589b02"}, +] + +[package.extras] +dev = ["hypothesis"] + +[[package]] +name = "jsonpatch" +version = "1.33" +description = "Apply JSON-Patches (RFC 6902)" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" +files = [ + {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"}, + {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"}, +] + +[package.dependencies] +jsonpointer = ">=1.9" + +[[package]] +name = "jsonpointer" +version = "2.4" +description = "Identify specific nodes in a JSON document (RFC 6901)" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" +files = [ + {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, + {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, +] + +[[package]] +name = "jsonschema" +version = "4.20.0" +description = "An implementation of JSON Schema validation for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema-4.20.0-py3-none-any.whl", hash = "sha256:ed6231f0429ecf966f5bc8dfef245998220549cbbcf140f913b7464c52c3b6b3"}, + {file = "jsonschema-4.20.0.tar.gz", hash = 
"sha256:4f614fd46d8d61258610998997743ec5492a648b33cf478c1ddc23ed4598a5fa"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +fqdn = {version = "*", optional = true, markers = "extra == \"format-nongpl\""} +idna = {version = "*", optional = true, markers = "extra == \"format-nongpl\""} +importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""} +isoduration = {version = "*", optional = true, markers = "extra == \"format-nongpl\""} +jsonpointer = {version = ">1.13", optional = true, markers = "extra == \"format-nongpl\""} +jsonschema-specifications = ">=2023.03.6" +pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""} +referencing = ">=0.28.4" +rfc3339-validator = {version = "*", optional = true, markers = "extra == \"format-nongpl\""} +rfc3986-validator = {version = ">0.1.0", optional = true, markers = "extra == \"format-nongpl\""} +rpds-py = ">=0.7.1" +uri-template = {version = "*", optional = true, markers = "extra == \"format-nongpl\""} +webcolors = {version = ">=1.11", optional = true, markers = "extra == \"format-nongpl\""} + +[package.extras] +format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"] + +[[package]] +name = "jsonschema-specifications" +version = "2023.11.1" +description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema_specifications-2023.11.1-py3-none-any.whl", hash = "sha256:f596778ab612b3fd29f72ea0d990393d0540a5aab18bf0407a46632eab540779"}, + {file = "jsonschema_specifications-2023.11.1.tar.gz", hash = "sha256:c9b234904ffe02f079bf91b14d79987faa685fd4b39c377a0996954c0090b9ca"}, +] + +[package.dependencies] +importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""} +referencing = ">=0.31.0" + +[[package]] +name = "jupyter" +version = "1.0.0" +description = "Jupyter metapackage. Install all the Jupyter components in one go." 
+optional = false +python-versions = "*" +files = [ + {file = "jupyter-1.0.0-py2.py3-none-any.whl", hash = "sha256:5b290f93b98ffbc21c0c7e749f054b3267782166d72fa5e3ed1ed4eaf34a2b78"}, + {file = "jupyter-1.0.0.tar.gz", hash = "sha256:d9dc4b3318f310e34c82951ea5d6683f67bed7def4b259fafbfe4f1beb1d8e5f"}, + {file = "jupyter-1.0.0.zip", hash = "sha256:3e1f86076bbb7c8c207829390305a2b1fe836d471ed54be66a3b8c41e7f46cc7"}, +] + +[package.dependencies] +ipykernel = "*" +ipywidgets = "*" +jupyter-console = "*" +nbconvert = "*" +notebook = "*" +qtconsole = "*" + +[[package]] +name = "jupyter-client" +version = "8.6.0" +description = "Jupyter protocol implementation and client libraries" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyter_client-8.6.0-py3-none-any.whl", hash = "sha256:909c474dbe62582ae62b758bca86d6518c85234bdee2d908c778db6d72f39d99"}, + {file = "jupyter_client-8.6.0.tar.gz", hash = "sha256:0642244bb83b4764ae60d07e010e15f0e2d275ec4e918a8f7b80fbbef3ca60c7"}, +] + +[package.dependencies] +importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +python-dateutil = ">=2.8.2" +pyzmq = ">=23.0" +tornado = ">=6.2" +traitlets = ">=5.3" + +[package.extras] +docs = ["ipykernel", "myst-parser", "pydata-sphinx-theme", "sphinx (>=4)", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] +test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"] + +[[package]] +name = "jupyter-console" +version = "6.6.3" +description = "Jupyter terminal console" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jupyter_console-6.6.3-py3-none-any.whl", hash = "sha256:309d33409fcc92ffdad25f0bcdf9a4a9daa61b6f341177570fdac03de5352485"}, + {file = "jupyter_console-6.6.3.tar.gz", hash = "sha256:566a4bf31c87adbfadf22cdf846e3069b59a71ed5da71d6ba4d8aaad14a53539"}, +] + +[package.dependencies] +ipykernel = ">=6.14" +ipython = "*" +jupyter-client = ">=7.0.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +prompt-toolkit = ">=3.0.30" +pygments = "*" +pyzmq = ">=17" +traitlets = ">=5.4" + +[package.extras] +test = ["flaky", "pexpect", "pytest"] + +[[package]] +name = "jupyter-core" +version = "5.5.0" +description = "Jupyter core package. A base package on which Jupyter projects rely." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyter_core-5.5.0-py3-none-any.whl", hash = "sha256:e11e02cd8ae0a9de5c6c44abf5727df9f2581055afe00b22183f621ba3585805"}, + {file = "jupyter_core-5.5.0.tar.gz", hash = "sha256:880b86053bf298a8724994f95e99b99130659022a4f7f45f563084b6223861d3"}, +] + +[package.dependencies] +platformdirs = ">=2.5" +pywin32 = {version = ">=300", markers = "sys_platform == \"win32\" and platform_python_implementation != \"PyPy\""} +traitlets = ">=5.3" + +[package.extras] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] +test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] + +[[package]] +name = "jupyter-events" +version = "0.9.0" +description = "Jupyter Event System library" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyter_events-0.9.0-py3-none-any.whl", hash = "sha256:d853b3c10273ff9bc8bb8b30076d65e2c9685579db736873de6c2232dde148bf"}, + {file = "jupyter_events-0.9.0.tar.gz", hash = "sha256:81ad2e4bc710881ec274d31c6c50669d71bbaa5dd9d01e600b56faa85700d399"}, +] + +[package.dependencies] +jsonschema = {version = ">=4.18.0", extras = ["format-nongpl"]} +python-json-logger = ">=2.0.4" +pyyaml = ">=5.3" +referencing = "*" +rfc3339-validator = "*" +rfc3986-validator = ">=0.1.1" +traitlets = ">=5.3" + +[package.extras] +cli = ["click", "rich"] +docs = ["jupyterlite-sphinx", "myst-parser", "pydata-sphinx-theme", "sphinxcontrib-spelling"] +test = ["click", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.19.0)", "pytest-console-scripts", "rich"] + +[[package]] +name = "jupyter-lsp" +version = "2.2.0" +description = "Multi-Language Server WebSocket proxy for Jupyter Notebook/Lab server" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyter-lsp-2.2.0.tar.gz", hash = "sha256:8ebbcb533adb41e5d635eb8fe82956b0aafbf0fd443b6c4bfa906edeeb8635a1"}, + {file = "jupyter_lsp-2.2.0-py3-none-any.whl", hash = "sha256:9e06b8b4f7dd50300b70dd1a78c0c3b0c3d8fa68e0f2d8a5d1fbab62072aca3f"}, +] + +[package.dependencies] +importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} +jupyter-server = ">=1.1.2" + +[[package]] +name = "jupyter-server" +version = "2.10.1" +description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyter_server-2.10.1-py3-none-any.whl", hash = "sha256:20519e355d951fc5e1b6ac5952854fe7620d0cfb56588fa4efe362a758977ed3"}, + {file = "jupyter_server-2.10.1.tar.gz", hash = "sha256:e6da2657a954a7879eed28cc08e0817b01ffd81d7eab8634660397b55f926472"}, +] + +[package.dependencies] +anyio = ">=3.1.0" +argon2-cffi = "*" +jinja2 = "*" +jupyter-client = ">=7.4.4" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-events = ">=0.9.0" +jupyter-server-terminals = "*" +nbconvert = ">=6.4.4" +nbformat = ">=5.3.0" +overrides = "*" +packaging = "*" +prometheus-client = "*" +pywinpty = {version = "*", markers = "os_name == \"nt\""} +pyzmq = ">=24" +send2trash = ">=1.8.2" +terminado = ">=0.8.3" +tornado = ">=6.2.0" +traitlets = ">=5.6.0" +websocket-client = "*" + +[package.extras] +docs = ["ipykernel", "jinja2", "jupyter-client", "jupyter-server", "myst-parser", "nbformat", "prometheus-client", "pydata-sphinx-theme", "send2trash", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-openapi (>=0.8.0)", "sphinxcontrib-spelling", "sphinxemoji", "tornado", "typing-extensions"] +test = ["flaky", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", "pytest-jupyter[server] (>=0.4)", "pytest-timeout", "requests"] + +[[package]] +name = "jupyter-server-terminals" +version = "0.4.4" +description = "A Jupyter Server Extension Providing Terminals." +optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyter_server_terminals-0.4.4-py3-none-any.whl", hash = "sha256:75779164661cec02a8758a5311e18bb8eb70c4e86c6b699403100f1585a12a36"}, + {file = "jupyter_server_terminals-0.4.4.tar.gz", hash = "sha256:57ab779797c25a7ba68e97bcfb5d7740f2b5e8a83b5e8102b10438041a7eac5d"}, +] + +[package.dependencies] +pywinpty = {version = ">=2.0.3", markers = "os_name == \"nt\""} +terminado = ">=0.8.3" + +[package.extras] +docs = ["jinja2", "jupyter-server", "mistune (<3.0)", "myst-parser", "nbformat", "packaging", "pydata-sphinx-theme", "sphinxcontrib-github-alt", "sphinxcontrib-openapi", "sphinxcontrib-spelling", "sphinxemoji", "tornado"] +test = ["coverage", "jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-cov", "pytest-jupyter[server] (>=0.5.3)", "pytest-timeout"] + +[[package]] +name = "jupyterlab" +version = "4.0.8" +description = "JupyterLab computational environment" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyterlab-4.0.8-py3-none-any.whl", hash = "sha256:2ff5aa2a51eb21df241d6011c236e88bd1ff9a5dbb75bebc54472f9c18bfffa4"}, + {file = "jupyterlab-4.0.8.tar.gz", hash = "sha256:c4fe93f977bcc987bd395d7fae5ab02e0c042bf4e0f7c95196f3e2e578c2fb3a"}, +] + +[package.dependencies] +async-lru = ">=1.0.0" +importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} +importlib-resources = {version = ">=1.4", markers = "python_version < \"3.9\""} +ipykernel = "*" +jinja2 = ">=3.0.3" +jupyter-core = "*" +jupyter-lsp = ">=2.0.0" +jupyter-server = ">=2.4.0,<3" +jupyterlab-server = ">=2.19.0,<3" +notebook-shim = ">=0.2" +packaging = "*" +tomli = {version = "*", markers = "python_version < \"3.11\""} +tornado = ">=6.2.0" +traitlets = "*" + +[package.extras] +dev = ["black[jupyter] (==23.10.1)", "build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.0.292)"] +docs = ["jsx-lexer", "myst-parser", "pydata-sphinx-theme (>=0.13.0)", "pytest", "pytest-check-links", "pytest-tornasync", "sphinx (>=1.8,<7.2.0)", "sphinx-copybutton"] +docs-screenshots = ["altair 
(==5.0.1)", "ipython (==8.14.0)", "ipywidgets (==8.0.6)", "jupyterlab-geojson (==3.4.0)", "jupyterlab-language-pack-zh-cn (==4.0.post0)", "matplotlib (==3.7.1)", "nbconvert (>=7.0.0)", "pandas (==2.0.2)", "scipy (==1.10.1)", "vega-datasets (==0.9.0)"] +test = ["coverage", "pytest (>=7.0)", "pytest-check-links (>=0.7)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter (>=0.5.3)", "pytest-timeout", "pytest-tornasync", "requests", "requests-cache", "virtualenv"] + +[[package]] +name = "jupyterlab-pygments" +version = "0.2.2" +description = "Pygments theme using JupyterLab CSS variables" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jupyterlab_pygments-0.2.2-py2.py3-none-any.whl", hash = "sha256:2405800db07c9f770863bcf8049a529c3dd4d3e28536638bd7c1c01d2748309f"}, + {file = "jupyterlab_pygments-0.2.2.tar.gz", hash = "sha256:7405d7fde60819d905a9fa8ce89e4cd830e318cdad22a0030f7a901da705585d"}, +] + +[[package]] +name = "jupyterlab-server" +version = "2.25.2" +description = "A set of server components for JupyterLab and JupyterLab like applications." +optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyterlab_server-2.25.2-py3-none-any.whl", hash = "sha256:5b1798c9cc6a44f65c757de9f97fc06fc3d42535afbf47d2ace5e964ab447aaf"}, + {file = "jupyterlab_server-2.25.2.tar.gz", hash = "sha256:bd0ec7a99ebcedc8bcff939ef86e52c378e44c2707e053fcd81d046ce979ee63"}, +] + +[package.dependencies] +babel = ">=2.10" +importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} +jinja2 = ">=3.0.3" +json5 = ">=0.9.0" +jsonschema = ">=4.18.0" +jupyter-server = ">=1.21,<3" +packaging = ">=21.3" +requests = ">=2.31" + +[package.extras] +docs = ["autodoc-traits", "jinja2 (<3.2.0)", "mistune (<4)", "myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-copybutton", "sphinxcontrib-openapi (>0.8)"] +openapi = ["openapi-core (>=0.18.0,<0.19.0)", "ruamel-yaml"] +test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-validator (>=0.6.0,<0.8.0)", "pytest (>=7.0)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter[server] (>=0.6.2)", "pytest-timeout", "requests-mock", "ruamel-yaml", "sphinxcontrib-spelling", "strict-rfc3339", "werkzeug"] + +[[package]] +name = "jupyterlab-widgets" +version = "3.0.9" +description = "Jupyter interactive widgets for JupyterLab" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jupyterlab_widgets-3.0.9-py3-none-any.whl", hash = "sha256:3cf5bdf5b897bf3bccf1c11873aa4afd776d7430200f765e0686bd352487b58d"}, + {file = "jupyterlab_widgets-3.0.9.tar.gz", hash = "sha256:6005a4e974c7beee84060fdfba341a3218495046de8ae3ec64888e5fe19fdb4c"}, +] + +[[package]] +name = "langsmith" +version = "0.0.65" +description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." +optional = false +python-versions = ">=3.8.1,<4.0" +files = [ + {file = "langsmith-0.0.65-py3-none-any.whl", hash = "sha256:92450957d1c6b6be814f9b726f3bc751deca684535fb404508ccad7aec1bb049"}, + {file = "langsmith-0.0.65.tar.gz", hash = "sha256:ef20e2e32392fb1a0fc5d171e8de595d868b4153a10cc119d7bf8418192c06b6"}, +] + +[package.dependencies] +pydantic = ">=1,<3" +requests = ">=2,<3" + +[[package]] +name = "markupsafe" +version = "2.1.3" +description = "Safely add untrusted strings to HTML/XML markup." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win32.whl", hash = "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431"}, + {file = "MarkupSafe-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, + {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, + {file = 
"MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win32.whl", hash = "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0"}, + {file = "MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl", hash = "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-win32.whl", hash = "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5"}, + {file = "MarkupSafe-2.1.3-cp38-cp38-win_amd64.whl", hash = "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = 
"sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win32.whl", hash = "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2"}, + {file = "MarkupSafe-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba"}, + {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, +] + +[[package]] +name = "matplotlib-inline" +version = "0.1.6" +description = "Inline Matplotlib backend for Jupyter" +optional = false +python-versions = ">=3.5" +files = [ + {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"}, + {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"}, +] + +[package.dependencies] +traitlets = "*" + +[[package]] +name = "mistune" +version = "3.0.2" +description = "A sane and fast Markdown parser with useful plugins and renderers" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mistune-3.0.2-py3-none-any.whl", hash = "sha256:71481854c30fdbc938963d3605b72501f5c10a9320ecd412c121c163a1c7d205"}, + {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, +] + +[[package]] +name = "mypy" +version = "0.991" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mypy-0.991-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7d17e0a9707d0772f4a7b878f04b4fd11f6f5bcb9b3813975a9b13c9332153ab"}, + {file = "mypy-0.991-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0714258640194d75677e86c786e80ccf294972cc76885d3ebbb560f11db0003d"}, + {file = "mypy-0.991-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0c8f3be99e8a8bd403caa8c03be619544bc2c77a7093685dcf308c6b109426c6"}, + {file = "mypy-0.991-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9ec663ed6c8f15f4ae9d3c04c989b744436c16d26580eaa760ae9dd5d662eb"}, + {file = "mypy-0.991-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4307270436fd7694b41f913eb09210faff27ea4979ecbcd849e57d2da2f65305"}, + {file = "mypy-0.991-cp310-cp310-win_amd64.whl", hash = "sha256:901c2c269c616e6cb0998b33d4adbb4a6af0ac4ce5cd078afd7bc95830e62c1c"}, + {file = "mypy-0.991-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d13674f3fb73805ba0c45eb6c0c3053d218aa1f7abead6e446d474529aafc372"}, + {file = "mypy-0.991-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1c8cd4fb70e8584ca1ed5805cbc7c017a3d1a29fb450621089ffed3e99d1857f"}, + {file = "mypy-0.991-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:209ee89fbb0deed518605edddd234af80506aec932ad28d73c08f1400ef80a33"}, + {file = "mypy-0.991-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37bd02ebf9d10e05b00d71302d2c2e6ca333e6c2a8584a98c00e038db8121f05"}, + {file = "mypy-0.991-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:26efb2fcc6b67e4d5a55561f39176821d2adf88f2745ddc72751b7890f3194ad"}, + {file = "mypy-0.991-cp311-cp311-win_amd64.whl", hash = 
"sha256:3a700330b567114b673cf8ee7388e949f843b356a73b5ab22dd7cff4742a5297"}, + {file = "mypy-0.991-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1f7d1a520373e2272b10796c3ff721ea1a0712288cafaa95931e66aa15798813"}, + {file = "mypy-0.991-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:641411733b127c3e0dab94c45af15fea99e4468f99ac88b39efb1ad677da5711"}, + {file = "mypy-0.991-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3d80e36b7d7a9259b740be6d8d906221789b0d836201af4234093cae89ced0cd"}, + {file = "mypy-0.991-cp37-cp37m-win_amd64.whl", hash = "sha256:e62ebaad93be3ad1a828a11e90f0e76f15449371ffeecca4a0a0b9adc99abcef"}, + {file = "mypy-0.991-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b86ce2c1866a748c0f6faca5232059f881cda6dda2a893b9a8373353cfe3715a"}, + {file = "mypy-0.991-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ac6e503823143464538efda0e8e356d871557ef60ccd38f8824a4257acc18d93"}, + {file = "mypy-0.991-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cca5adf694af539aeaa6ac633a7afe9bbd760df9d31be55ab780b77ab5ae8bf"}, + {file = "mypy-0.991-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12c56bf73cdab116df96e4ff39610b92a348cc99a1307e1da3c3768bbb5b135"}, + {file = "mypy-0.991-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:652b651d42f155033a1967739788c436491b577b6a44e4c39fb340d0ee7f0d70"}, + {file = "mypy-0.991-cp38-cp38-win_amd64.whl", hash = "sha256:4175593dc25d9da12f7de8de873a33f9b2b8bdb4e827a7cae952e5b1a342e243"}, + {file = "mypy-0.991-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:98e781cd35c0acf33eb0295e8b9c55cdbef64fcb35f6d3aa2186f289bed6e80d"}, + {file = "mypy-0.991-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6d7464bac72a85cb3491c7e92b5b62f3dcccb8af26826257760a552a5e244aa5"}, + {file = "mypy-0.991-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c9166b3f81a10cdf9b49f2d594b21b31adadb3d5e9db9b834866c3258b695be3"}, + {file = "mypy-0.991-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8472f736a5bfb159a5e36740847808f6f5b659960115ff29c7cecec1741c648"}, + {file = "mypy-0.991-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5e80e758243b97b618cdf22004beb09e8a2de1af481382e4d84bc52152d1c476"}, + {file = "mypy-0.991-cp39-cp39-win_amd64.whl", hash = "sha256:74e259b5c19f70d35fcc1ad3d56499065c601dfe94ff67ae48b85596b9ec1461"}, + {file = "mypy-0.991-py3-none-any.whl", hash = "sha256:de32edc9b0a7e67c2775e574cb061a537660e51210fbf6006b0b36ea695ae9bb"}, + {file = "mypy-0.991.tar.gz", hash = "sha256:3c0165ba8f354a6d9881809ef29f1a9318a236a6d81c690094c5df32107bde06"}, +] + +[package.dependencies] +mypy-extensions = ">=0.4.3" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=3.10" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +python2 = ["typed-ast (>=1.4.0,<2)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + +[[package]] +name = "nbclient" +version = "0.9.0" +description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." 
+optional = false +python-versions = ">=3.8.0" +files = [ + {file = "nbclient-0.9.0-py3-none-any.whl", hash = "sha256:a3a1ddfb34d4a9d17fc744d655962714a866639acd30130e9be84191cd97cd15"}, + {file = "nbclient-0.9.0.tar.gz", hash = "sha256:4b28c207877cf33ef3a9838cdc7a54c5ceff981194a82eac59d558f05487295e"}, +] + +[package.dependencies] +jupyter-client = ">=6.1.12" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +nbformat = ">=5.1" +traitlets = ">=5.4" + +[package.extras] +dev = ["pre-commit"] +docs = ["autodoc-traits", "mock", "moto", "myst-parser", "nbclient[test]", "sphinx (>=1.7)", "sphinx-book-theme", "sphinxcontrib-spelling"] +test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"] + +[[package]] +name = "nbconvert" +version = "7.11.0" +description = "Converting Jupyter Notebooks" +optional = false +python-versions = ">=3.8" +files = [ + {file = "nbconvert-7.11.0-py3-none-any.whl", hash = "sha256:d1d417b7f34a4e38887f8da5bdfd12372adf3b80f995d57556cb0972c68909fe"}, + {file = "nbconvert-7.11.0.tar.gz", hash = "sha256:abedc01cf543177ffde0bfc2a69726d5a478f6af10a332fc1bf29fcb4f0cf000"}, +] + +[package.dependencies] +beautifulsoup4 = "*" +bleach = "!=5.0.0" +defusedxml = "*" +importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.10\""} +jinja2 = ">=3.0" +jupyter-core = ">=4.7" +jupyterlab-pygments = "*" +markupsafe = ">=2.0" +mistune = ">=2.0.3,<4" +nbclient = ">=0.5.0" +nbformat = ">=5.7" +packaging = "*" +pandocfilters = ">=1.4.1" +pygments = ">=2.4.1" +tinycss2 = "*" +traitlets = ">=5.1" + +[package.extras] +all = ["nbconvert[docs,qtpdf,serve,test,webpdf]"] +docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sphinx-theme", "sphinx (==5.0.2)", "sphinxcontrib-spelling"] +qtpdf = ["nbconvert[qtpng]"] +qtpng = ["pyqtwebengine (>=5.15)"] +serve = ["tornado (>=6.1)"] +test = ["flaky", "ipykernel", "ipywidgets (>=7)", "pytest"] +webpdf = ["playwright"] + +[[package]] +name = "nbformat" +version = "5.9.2" +description = "The Jupyter Notebook format" +optional = false +python-versions = ">=3.8" +files = [ + {file = "nbformat-5.9.2-py3-none-any.whl", hash = "sha256:1c5172d786a41b82bcfd0c23f9e6b6f072e8fb49c39250219e4acfff1efe89e9"}, + {file = "nbformat-5.9.2.tar.gz", hash = "sha256:5f98b5ba1997dff175e77e0c17d5c10a96eaed2cbd1de3533d1fc35d5e111192"}, +] + +[package.dependencies] +fastjsonschema = "*" +jsonschema = ">=2.6" +jupyter-core = "*" +traitlets = ">=5.1" + +[package.extras] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] +test = ["pep440", "pre-commit", "pytest", "testpath"] + +[[package]] +name = "nest-asyncio" +version = "1.5.8" +description = "Patch asyncio to allow nested event loops" +optional = false +python-versions = ">=3.5" +files = [ + {file = "nest_asyncio-1.5.8-py3-none-any.whl", hash = "sha256:accda7a339a70599cb08f9dd09a67e0c2ef8d8d6f4c07f96ab203f2ae254e48d"}, + {file = "nest_asyncio-1.5.8.tar.gz", hash = "sha256:25aa2ca0d2a5b5531956b9e273b45cf664cae2b145101d73b86b199978d48fdb"}, +] + +[[package]] +name = "notebook" +version = "7.0.6" +description = "Jupyter Notebook - A web-based notebook environment for interactive computing" +optional = false +python-versions = ">=3.8" +files = [ + {file = "notebook-7.0.6-py3-none-any.whl", hash = "sha256:0fe8f67102fea3744fedf652e4c15339390902ca70c5a31c4f547fa23da697cc"}, + {file = "notebook-7.0.6.tar.gz", hash = 
"sha256:ec6113b06529019f7f287819af06c97a2baf7a95ac21a8f6e32192898e9f9a58"}, +] + +[package.dependencies] +jupyter-server = ">=2.4.0,<3" +jupyterlab = ">=4.0.2,<5" +jupyterlab-server = ">=2.22.1,<3" +notebook-shim = ">=0.2,<0.3" +tornado = ">=6.2.0" + +[package.extras] +dev = ["hatch", "pre-commit"] +docs = ["myst-parser", "nbsphinx", "pydata-sphinx-theme", "sphinx (>=1.3.6)", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] +test = ["importlib-resources (>=5.0)", "ipykernel", "jupyter-server[test] (>=2.4.0,<3)", "jupyterlab-server[test] (>=2.22.1,<3)", "nbval", "pytest (>=7.0)", "pytest-console-scripts", "pytest-timeout", "pytest-tornasync", "requests"] + +[[package]] +name = "notebook-shim" +version = "0.2.3" +description = "A shim layer for notebook traits and config" +optional = false +python-versions = ">=3.7" +files = [ + {file = "notebook_shim-0.2.3-py3-none-any.whl", hash = "sha256:a83496a43341c1674b093bfcebf0fe8e74cbe7eda5fd2bbc56f8e39e1486c0c7"}, + {file = "notebook_shim-0.2.3.tar.gz", hash = "sha256:f69388ac283ae008cd506dda10d0288b09a017d822d5e8c7129a152cbd3ce7e9"}, +] + +[package.dependencies] +jupyter-server = ">=1.8,<3" + +[package.extras] +test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync"] + +[[package]] +name = "overrides" +version = "7.4.0" +description = "A decorator to automatically detect mismatch when overriding a method." +optional = false +python-versions = ">=3.6" +files = [ + {file = "overrides-7.4.0-py3-none-any.whl", hash = "sha256:3ad24583f86d6d7a49049695efe9933e67ba62f0c7625d53c59fa832ce4b8b7d"}, + {file = "overrides-7.4.0.tar.gz", hash = "sha256:9502a3cca51f4fac40b5feca985b6703a5c1f6ad815588a7ca9e285b9dca6757"}, +] + +[[package]] +name = "packaging" +version = "23.2" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, + {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, +] + +[[package]] +name = "pandocfilters" +version = "1.5.0" +description = "Utilities for writing pandoc filters in python" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "pandocfilters-1.5.0-py2.py3-none-any.whl", hash = "sha256:33aae3f25fd1a026079f5d27bdd52496f0e0803b3469282162bafdcbdf6ef14f"}, + {file = "pandocfilters-1.5.0.tar.gz", hash = "sha256:0b679503337d233b4339a817bfc8c50064e2eff681314376a47cb582305a7a38"}, +] + +[[package]] +name = "parso" +version = "0.8.3" +description = "A Python Parser" +optional = false +python-versions = ">=3.6" +files = [ + {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"}, + {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"}, +] + +[package.extras] +qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] +testing = ["docopt", "pytest (<6.0.0)"] + +[[package]] +name = "pexpect" +version = "4.8.0" +description = "Pexpect allows easy control of interactive console applications." 
+optional = false +python-versions = "*" +files = [ + {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"}, + {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"}, +] + +[package.dependencies] +ptyprocess = ">=0.5" + +[[package]] +name = "pickleshare" +version = "0.7.5" +description = "Tiny 'shelve'-like database with concurrency support" +optional = false +python-versions = "*" +files = [ + {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"}, + {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"}, +] + +[[package]] +name = "pkgutil-resolve-name" +version = "1.3.10" +description = "Resolve a name to an object." +optional = false +python-versions = ">=3.6" +files = [ + {file = "pkgutil_resolve_name-1.3.10-py3-none-any.whl", hash = "sha256:ca27cc078d25c5ad71a9de0a7a330146c4e014c2462d9af19c6b828280649c5e"}, + {file = "pkgutil_resolve_name-1.3.10.tar.gz", hash = "sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174"}, +] + +[[package]] +name = "platformdirs" +version = "4.0.0" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +optional = false +python-versions = ">=3.7" +files = [ + {file = "platformdirs-4.0.0-py3-none-any.whl", hash = "sha256:118c954d7e949b35437270383a3f2531e99dd93cf7ce4dc8340d3356d30f173b"}, + {file = "platformdirs-4.0.0.tar.gz", hash = "sha256:cb633b2bcf10c51af60beb0ab06d2f1d69064b43abf4c185ca6b28865f3f9731"}, +] + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] + +[[package]] +name = "pluggy" +version = "1.3.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, + {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "prometheus-client" +version = "0.18.0" +description = "Python client for the Prometheus monitoring system." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "prometheus_client-0.18.0-py3-none-any.whl", hash = "sha256:8de3ae2755f890826f4b6479e5571d4f74ac17a81345fe69a6778fdb92579184"}, + {file = "prometheus_client-0.18.0.tar.gz", hash = "sha256:35f7a8c22139e2bb7ca5a698e92d38145bc8dc74c1c0bf56f25cca886a764e17"}, +] + +[package.extras] +twisted = ["twisted"] + +[[package]] +name = "prompt-toolkit" +version = "3.0.41" +description = "Library for building powerful interactive command lines in Python" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "prompt_toolkit-3.0.41-py3-none-any.whl", hash = "sha256:f36fe301fafb7470e86aaf90f036eef600a3210be4decf461a5b1ca8403d3cb2"}, + {file = "prompt_toolkit-3.0.41.tar.gz", hash = "sha256:941367d97fc815548822aa26c2a269fdc4eb21e9ec05fc5d447cf09bad5d75f0"}, +] + +[package.dependencies] +wcwidth = "*" + +[[package]] +name = "psutil" +version = "5.9.6" +description = "Cross-platform lib for process and system monitoring in Python." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +files = [ + {file = "psutil-5.9.6-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d"}, + {file = "psutil-5.9.6-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c"}, + {file = "psutil-5.9.6-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28"}, + {file = "psutil-5.9.6-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017"}, + {file = "psutil-5.9.6-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c"}, + {file = "psutil-5.9.6-cp27-none-win32.whl", hash = "sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9"}, + {file = "psutil-5.9.6-cp27-none-win_amd64.whl", hash = "sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac"}, + {file = "psutil-5.9.6-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a"}, + {file = "psutil-5.9.6-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c"}, + {file = "psutil-5.9.6-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4"}, + {file = "psutil-5.9.6-cp36-cp36m-win32.whl", hash = "sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602"}, + {file = "psutil-5.9.6-cp36-cp36m-win_amd64.whl", hash = "sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa"}, + {file = "psutil-5.9.6-cp37-abi3-win32.whl", hash = "sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c"}, + {file = "psutil-5.9.6-cp37-abi3-win_amd64.whl", hash = "sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a"}, + {file = "psutil-5.9.6-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57"}, + {file = "psutil-5.9.6.tar.gz", hash = "sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a"}, +] + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + +[[package]] +name = "ptyprocess" 
+version = "0.7.0" +description = "Run a subprocess in a pseudo terminal" +optional = false +python-versions = "*" +files = [ + {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, + {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, +] + +[[package]] +name = "pure-eval" +version = "0.2.2" +description = "Safely evaluate AST nodes without side effects" +optional = false +python-versions = "*" +files = [ + {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"}, + {file = "pure_eval-0.2.2.tar.gz", hash = "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"}, +] + +[package.extras] +tests = ["pytest"] + +[[package]] +name = "pycparser" +version = "2.21" +description = "C parser in Python" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, + {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, +] + +[[package]] +name = "pydantic" +version = "2.5.1" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic-2.5.1-py3-none-any.whl", hash = "sha256:dc5244a8939e0d9a68f1f1b5f550b2e1c879912033b1becbedb315accc75441b"}, + {file = "pydantic-2.5.1.tar.gz", hash = "sha256:0b8be5413c06aadfbe56f6dc1d45c9ed25fd43264414c571135c97dd77c2bedb"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +pydantic-core = "2.14.3" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.14.3" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic_core-2.14.3-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:ba44fad1d114539d6a1509966b20b74d2dec9a5b0ee12dd7fd0a1bb7b8785e5f"}, + {file = "pydantic_core-2.14.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4a70d23eedd88a6484aa79a732a90e36701048a1509078d1b59578ef0ea2cdf5"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cc24728a1a9cef497697e53b3d085fb4d3bc0ef1ef4d9b424d9cf808f52c146"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ab4a2381005769a4af2ffddae74d769e8a4aae42e970596208ec6d615c6fb080"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:905a12bf088d6fa20e094f9a477bf84bd823651d8b8384f59bcd50eaa92e6a52"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:38aed5a1bbc3025859f56d6a32f6e53ca173283cb95348e03480f333b1091e7d"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1767bd3f6370458e60c1d3d7b1d9c2751cc1ad743434e8ec84625a610c8b9195"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7cb0c397f29688a5bd2c0dbd44451bc44ebb9b22babc90f97db5ec3e5bb69977"}, + {file = "pydantic_core-2.14.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9ff737f24b34ed26de62d481ef522f233d3c5927279f6b7229de9b0deb3f76b5"}, + {file = 
"pydantic_core-2.14.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a1a39fecb5f0b19faee9a8a8176c805ed78ce45d760259a4ff3d21a7daa4dfc1"}, + {file = "pydantic_core-2.14.3-cp310-none-win32.whl", hash = "sha256:ccbf355b7276593c68fa824030e68cb29f630c50e20cb11ebb0ee450ae6b3d08"}, + {file = "pydantic_core-2.14.3-cp310-none-win_amd64.whl", hash = "sha256:536e1f58419e1ec35f6d1310c88496f0d60e4f182cacb773d38076f66a60b149"}, + {file = "pydantic_core-2.14.3-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:f1f46700402312bdc31912f6fc17f5ecaaaa3bafe5487c48f07c800052736289"}, + {file = "pydantic_core-2.14.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:88ec906eb2d92420f5b074f59cf9e50b3bb44f3cb70e6512099fdd4d88c2f87c"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:056ea7cc3c92a7d2a14b5bc9c9fa14efa794d9f05b9794206d089d06d3433dc7"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:076edc972b68a66870cec41a4efdd72a6b655c4098a232314b02d2bfa3bfa157"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e71f666c3bf019f2490a47dddb44c3ccea2e69ac882f7495c68dc14d4065eac2"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f518eac285c9632be337323eef9824a856f2680f943a9b68ac41d5f5bad7df7c"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dbab442a8d9ca918b4ed99db8d89d11b1f067a7dadb642476ad0889560dac79"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0653fb9fc2fa6787f2fa08631314ab7fc8070307bd344bf9471d1b7207c24623"}, + {file = "pydantic_core-2.14.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c54af5069da58ea643ad34ff32fd6bc4eebb8ae0fef9821cd8919063e0aeeaab"}, + {file = "pydantic_core-2.14.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cc956f78651778ec1ab105196e90e0e5f5275884793ab67c60938c75bcca3989"}, + {file = "pydantic_core-2.14.3-cp311-none-win32.whl", hash = "sha256:5b73441a1159f1fb37353aaefb9e801ab35a07dd93cb8177504b25a317f4215a"}, + {file = "pydantic_core-2.14.3-cp311-none-win_amd64.whl", hash = "sha256:7349f99f1ef8b940b309179733f2cad2e6037a29560f1b03fdc6aa6be0a8d03c"}, + {file = "pydantic_core-2.14.3-cp311-none-win_arm64.whl", hash = "sha256:ec79dbe23702795944d2ae4c6925e35a075b88acd0d20acde7c77a817ebbce94"}, + {file = "pydantic_core-2.14.3-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:8f5624f0f67f2b9ecaa812e1dfd2e35b256487566585160c6c19268bf2ffeccc"}, + {file = "pydantic_core-2.14.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6c2d118d1b6c9e2d577e215567eedbe11804c3aafa76d39ec1f8bc74e918fd07"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe863491664c6720d65ae438d4efaa5eca766565a53adb53bf14bc3246c72fe0"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:136bc7247e97a921a020abbd6ef3169af97569869cd6eff41b6a15a73c44ea9b"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aeafc7f5bbddc46213707266cadc94439bfa87ecf699444de8be044d6d6eb26f"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e16aaf788f1de5a85c8f8fcc9c1ca1dd7dd52b8ad30a7889ca31c7c7606615b8"}, + {file = 
"pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8fc652c354d3362e2932a79d5ac4bbd7170757a41a62c4fe0f057d29f10bebb"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f1b92e72babfd56585c75caf44f0b15258c58e6be23bc33f90885cebffde3400"}, + {file = "pydantic_core-2.14.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:75f3f534f33651b73f4d3a16d0254de096f43737d51e981478d580f4b006b427"}, + {file = "pydantic_core-2.14.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c9ffd823c46e05ef3eb28b821aa7bc501efa95ba8880b4a1380068e32c5bed47"}, + {file = "pydantic_core-2.14.3-cp312-none-win32.whl", hash = "sha256:12e05a76b223577a4696c76d7a6b36a0ccc491ffb3c6a8cf92d8001d93ddfd63"}, + {file = "pydantic_core-2.14.3-cp312-none-win_amd64.whl", hash = "sha256:1582f01eaf0537a696c846bea92082082b6bfc1103a88e777e983ea9fbdc2a0f"}, + {file = "pydantic_core-2.14.3-cp312-none-win_arm64.whl", hash = "sha256:96fb679c7ca12a512d36d01c174a4fbfd912b5535cc722eb2c010c7b44eceb8e"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:71ed769b58d44e0bc2701aa59eb199b6665c16e8a5b8b4a84db01f71580ec448"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:5402ee0f61e7798ea93a01b0489520f2abfd9b57b76b82c93714c4318c66ca06"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eaab9dc009e22726c62fe3b850b797e7f0e7ba76d245284d1064081f512c7226"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:92486a04d54987054f8b4405a9af9d482e5100d6fe6374fc3303015983fc8bda"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cf08b43d1d5d1678f295f0431a4a7e1707d4652576e1d0f8914b5e0213bfeee5"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8ca13480ce16daad0504be6ce893b0ee8ec34cd43b993b754198a89e2787f7e"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44afa3c18d45053fe8d8228950ee4c8eaf3b5a7f3b64963fdeac19b8342c987f"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:56814b41486e2d712a8bc02a7b1f17b87fa30999d2323bbd13cf0e52296813a1"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c3dc2920cc96f9aa40c6dc54256e436cc95c0a15562eb7bd579e1811593c377e"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:e483b8b913fcd3b48badec54185c150cb7ab0e6487914b84dc7cde2365e0c892"}, + {file = "pydantic_core-2.14.3-cp37-none-win32.whl", hash = "sha256:364dba61494e48f01ef50ae430e392f67ee1ee27e048daeda0e9d21c3ab2d609"}, + {file = "pydantic_core-2.14.3-cp37-none-win_amd64.whl", hash = "sha256:a402ae1066be594701ac45661278dc4a466fb684258d1a2c434de54971b006ca"}, + {file = "pydantic_core-2.14.3-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:10904368261e4509c091cbcc067e5a88b070ed9a10f7ad78f3029c175487490f"}, + {file = "pydantic_core-2.14.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:260692420028319e201b8649b13ac0988974eeafaaef95d0dfbf7120c38dc000"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c1bf1a7b05a65d3b37a9adea98e195e0081be6b17ca03a86f92aeb8b110f468"}, + {file = 
"pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d7abd17a838a52140e3aeca271054e321226f52df7e0a9f0da8f91ea123afe98"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a5c51460ede609fbb4fa883a8fe16e749964ddb459966d0518991ec02eb8dfb9"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d06c78074646111fb01836585f1198367b17d57c9f427e07aaa9ff499003e58d"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af452e69446fadf247f18ac5d153b1f7e61ef708f23ce85d8c52833748c58075"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e3ad4968711fb379a67c8c755beb4dae8b721a83737737b7bcee27c05400b047"}, + {file = "pydantic_core-2.14.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c5ea0153482e5b4d601c25465771c7267c99fddf5d3f3bdc238ef930e6d051cf"}, + {file = "pydantic_core-2.14.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:96eb10ef8920990e703da348bb25fedb8b8653b5966e4e078e5be382b430f9e0"}, + {file = "pydantic_core-2.14.3-cp38-none-win32.whl", hash = "sha256:ea1498ce4491236d1cffa0eee9ad0968b6ecb0c1cd711699c5677fc689905f00"}, + {file = "pydantic_core-2.14.3-cp38-none-win_amd64.whl", hash = "sha256:2bc736725f9bd18a60eec0ed6ef9b06b9785454c8d0105f2be16e4d6274e63d0"}, + {file = "pydantic_core-2.14.3-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:1ea992659c03c3ea811d55fc0a997bec9dde863a617cc7b25cfde69ef32e55af"}, + {file = "pydantic_core-2.14.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d2b53e1f851a2b406bbb5ac58e16c4a5496038eddd856cc900278fa0da97f3fc"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c7f8e8a7cf8e81ca7d44bea4f181783630959d41b4b51d2f74bc50f348a090f"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8d3b9c91eeb372a64ec6686c1402afd40cc20f61a0866850f7d989b6bf39a41a"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ef3e2e407e4cad2df3c89488a761ed1f1c33f3b826a2ea9a411b0a7d1cccf1b"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f86f20a9d5bee1a6ede0f2757b917bac6908cde0f5ad9fcb3606db1e2968bcf5"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61beaa79d392d44dc19d6f11ccd824d3cccb865c4372157c40b92533f8d76dd0"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d41df8e10b094640a6b234851b624b76a41552f637b9fb34dc720b9fe4ef3be4"}, + {file = "pydantic_core-2.14.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c08ac60c3caa31f825b5dbac47e4875bd4954d8f559650ad9e0b225eaf8ed0c"}, + {file = "pydantic_core-2.14.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:98d8b3932f1a369364606417ded5412c4ffb15bedbcf797c31317e55bd5d920e"}, + {file = "pydantic_core-2.14.3-cp39-none-win32.whl", hash = "sha256:caa94726791e316f0f63049ee00dff3b34a629b0d099f3b594770f7d0d8f1f56"}, + {file = "pydantic_core-2.14.3-cp39-none-win_amd64.whl", hash = "sha256:2494d20e4c22beac30150b4be3b8339bf2a02ab5580fa6553ca274bc08681a65"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:fe272a72c7ed29f84c42fedd2d06c2f9858dc0c00dae3b34ba15d6d8ae0fbaaf"}, + {file = 
"pydantic_core-2.14.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7e63a56eb7fdee1587d62f753ccd6d5fa24fbeea57a40d9d8beaef679a24bdd6"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7692f539a26265cece1e27e366df5b976a6db6b1f825a9e0466395b314ee48b"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af46f0b7a1342b49f208fed31f5a83b8495bb14b652f621e0a6787d2f10f24ee"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6e2f9d76c00e805d47f19c7a96a14e4135238a7551a18bfd89bb757993fd0933"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:de52ddfa6e10e892d00f747bf7135d7007302ad82e243cf16d89dd77b03b649d"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:38113856c7fad8c19be7ddd57df0c3e77b1b2336459cb03ee3903ce9d5e236ce"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:354db020b1f8f11207b35360b92d95725621eb92656725c849a61e4b550f4acc"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:76fc18653a5c95e5301a52d1b5afb27c9adc77175bf00f73e94f501caf0e05ad"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2646f8270f932d79ba61102a15ea19a50ae0d43b314e22b3f8f4b5fabbfa6e38"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37dad73a2f82975ed563d6a277fd9b50e5d9c79910c4aec787e2d63547202315"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:113752a55a8eaece2e4ac96bc8817f134c2c23477e477d085ba89e3aa0f4dc44"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:8488e973547e8fb1b4193fd9faf5236cf1b7cd5e9e6dc7ff6b4d9afdc4c720cb"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:3d1dde10bd9962b1434053239b1d5490fc31a2b02d8950a5f731bc584c7a5a0f"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:2c83892c7bf92b91d30faca53bb8ea21f9d7e39f0ae4008ef2c2f91116d0464a"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:849cff945284c577c5f621d2df76ca7b60f803cc8663ff01b778ad0af0e39bb9"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa89919fbd8a553cd7d03bf23d5bc5deee622e1b5db572121287f0e64979476"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf15145b1f8056d12c67255cd3ce5d317cd4450d5ee747760d8d088d85d12a2d"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4cc6bb11f4e8e5ed91d78b9880774fbc0856cb226151b0a93b549c2b26a00c19"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:832d16f248ca0cc96929139734ec32d21c67669dcf8a9f3f733c85054429c012"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b02b5e1f54c3396c48b665050464803c23c685716eb5d82a1d81bf81b5230da4"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:1f2d4516c32255782153e858f9a900ca6deadfb217fd3fb21bb2b60b4e04d04d"}, + {file = 
"pydantic_core-2.14.3-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:0a3e51c2be472b7867eb0c5d025b91400c2b73a0823b89d4303a9097e2ec6655"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:df33902464410a1f1a0411a235f0a34e7e129f12cb6340daca0f9d1390f5fe10"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27828f0227b54804aac6fb077b6bb48e640b5435fdd7fbf0c274093a7b78b69c"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e2979dc80246e18e348de51246d4c9b410186ffa3c50e77924bec436b1e36cb"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b28996872b48baf829ee75fa06998b607c66a4847ac838e6fd7473a6b2ab68e7"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ca55c9671bb637ce13d18ef352fd32ae7aba21b4402f300a63f1fb1fd18e0364"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:aecd5ed096b0e5d93fb0367fd8f417cef38ea30b786f2501f6c34eabd9062c38"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:44aaf1a07ad0824e407dafc637a852e9a44d94664293bbe7d8ee549c356c8882"}, + {file = "pydantic_core-2.14.3.tar.gz", hash = "sha256:3ad083df8fe342d4d8d00cc1d3c1a23f0dc84fce416eb301e69f1ddbbe124d3f"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + +[[package]] +name = "pygments" +version = "2.17.0" +description = "Pygments is a syntax highlighting package written in Python." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pygments-2.17.0-py3-none-any.whl", hash = "sha256:cd0c46944b2551af02ecc15961050182ea120d3895000e2676160820f3421527"}, + {file = "pygments-2.17.0.tar.gz", hash = "sha256:edaa0fa2453d055d0ac94449d1f73ec7bc52c5e318204da1377c1392978c4a8d"}, +] + +[package.extras] +plugins = ["importlib-metadata"] +windows-terminal = ["colorama (>=0.4.6)"] + +[[package]] +name = "pytest" +version = "7.4.3" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"}, + {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-asyncio" +version = "0.21.1" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"}, + {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy 
(>=0.931)", "pytest-trio (>=0.7.0)"] + +[[package]] +name = "pytest-mock" +version = "3.12.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"}, + {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"}, +] + +[package.dependencies] +pytest = ">=5.0" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + +[[package]] +name = "pytest-watcher" +version = "0.3.4" +description = "Automatically rerun your tests on file modifications" +optional = false +python-versions = ">=3.7.0,<4.0.0" +files = [ + {file = "pytest_watcher-0.3.4-py3-none-any.whl", hash = "sha256:edd2bd9c8a1fb14d48c9f4947234065eb9b4c1acedc0bf213b1f12501dfcffd3"}, + {file = "pytest_watcher-0.3.4.tar.gz", hash = "sha256:d39491ba15b589221bb9a78ef4bed3d5d1503aed08209b1a138aeb95b9117a18"}, +] + +[package.dependencies] +tomli = {version = ">=2.0.1,<3.0.0", markers = "python_version < \"3.11\""} +watchdog = ">=2.0.0" + +[[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "python-json-logger" +version = "2.0.7" +description = "A python library adding a json log formatter" +optional = false +python-versions = ">=3.6" +files = [ + {file = "python-json-logger-2.0.7.tar.gz", hash = "sha256:23e7ec02d34237c5aa1e29a070193a4ea87583bb4e7f8fd06d3de8264c4b2e1c"}, + {file = "python_json_logger-2.0.7-py3-none-any.whl", hash = "sha256:f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd"}, +] + +[[package]] +name = "pytz" +version = "2023.3.post1" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2023.3.post1-py2.py3-none-any.whl", hash = "sha256:ce42d816b81b68506614c11e8937d3aa9e41007ceb50bfdcb0749b921bf646c7"}, + {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"}, +] + +[[package]] +name = "pywin32" +version = "306" +description = "Python for Window Extensions" +optional = false +python-versions = "*" +files = [ + {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, + {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, + {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"}, + {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"}, + {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"}, + {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"}, + {file = 
"pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"}, + {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"}, + {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"}, + {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"}, + {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"}, + {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"}, + {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"}, + {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, +] + +[[package]] +name = "pywinpty" +version = "2.0.12" +description = "Pseudo terminal support for Windows from Python." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pywinpty-2.0.12-cp310-none-win_amd64.whl", hash = "sha256:21319cd1d7c8844fb2c970fb3a55a3db5543f112ff9cfcd623746b9c47501575"}, + {file = "pywinpty-2.0.12-cp311-none-win_amd64.whl", hash = "sha256:853985a8f48f4731a716653170cd735da36ffbdc79dcb4c7b7140bce11d8c722"}, + {file = "pywinpty-2.0.12-cp312-none-win_amd64.whl", hash = "sha256:1617b729999eb6713590e17665052b1a6ae0ad76ee31e60b444147c5b6a35dca"}, + {file = "pywinpty-2.0.12-cp38-none-win_amd64.whl", hash = "sha256:189380469ca143d06e19e19ff3fba0fcefe8b4a8cc942140a6b863aed7eebb2d"}, + {file = "pywinpty-2.0.12-cp39-none-win_amd64.whl", hash = "sha256:7520575b6546db23e693cbd865db2764097bd6d4ef5dc18c92555904cd62c3d4"}, + {file = "pywinpty-2.0.12.tar.gz", hash = "sha256:8197de460ae8ebb7f5d1701dfa1b5df45b157bb832e92acba316305e18ca00dd"}, +] + +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = 
"PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + +[[package]] +name = "pyzmq" +version = "25.1.1" +description = "Python bindings for 0MQ" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:381469297409c5adf9a0e884c5eb5186ed33137badcbbb0560b86e910a2f1e76"}, + {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:955215ed0604dac5b01907424dfa28b40f2b2292d6493445dd34d0dfa72586a8"}, + {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:985bbb1316192b98f32e25e7b9958088431d853ac63aca1d2c236f40afb17c83"}, + {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:afea96f64efa98df4da6958bae37f1cbea7932c35878b185e5982821bc883369"}, + {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76705c9325d72a81155bb6ab48d4312e0032bf045fb0754889133200f7a0d849"}, + {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:77a41c26205d2353a4c94d02be51d6cbdf63c06fbc1295ea57dad7e2d3381b71"}, + {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:12720a53e61c3b99d87262294e2b375c915fea93c31fc2336898c26d7aed34cd"}, + {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:57459b68e5cd85b0be8184382cefd91959cafe79ae019e6b1ae6e2ba8a12cda7"}, + {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:292fe3fc5ad4a75bc8df0dfaee7d0babe8b1f4ceb596437213821f761b4589f9"}, + {file = "pyzmq-25.1.1-cp310-cp310-win32.whl", hash = "sha256:35b5ab8c28978fbbb86ea54958cd89f5176ce747c1fb3d87356cf698048a7790"}, + {file = "pyzmq-25.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:11baebdd5fc5b475d484195e49bae2dc64b94a5208f7c89954e9e354fc609d8f"}, + {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:d20a0ddb3e989e8807d83225a27e5c2eb2260eaa851532086e9e0fa0d5287d83"}, + {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e1c1be77bc5fb77d923850f82e55a928f8638f64a61f00ff18a67c7404faf008"}, + {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d89528b4943d27029a2818f847c10c2cecc79fa9590f3cb1860459a5be7933eb"}, + {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:90f26dc6d5f241ba358bef79be9ce06de58d477ca8485e3291675436d3827cf8"}, + {file = 
"pyzmq-25.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2b92812bd214018e50b6380ea3ac0c8bb01ac07fcc14c5f86a5bb25e74026e9"}, + {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:2f957ce63d13c28730f7fd6b72333814221c84ca2421298f66e5143f81c9f91f"}, + {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:047a640f5c9c6ade7b1cc6680a0e28c9dd5a0825135acbd3569cc96ea00b2505"}, + {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7f7e58effd14b641c5e4dec8c7dab02fb67a13df90329e61c869b9cc607ef752"}, + {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c2910967e6ab16bf6fbeb1f771c89a7050947221ae12a5b0b60f3bca2ee19bca"}, + {file = "pyzmq-25.1.1-cp311-cp311-win32.whl", hash = "sha256:76c1c8efb3ca3a1818b837aea423ff8a07bbf7aafe9f2f6582b61a0458b1a329"}, + {file = "pyzmq-25.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:44e58a0554b21fc662f2712814a746635ed668d0fbc98b7cb9d74cb798d202e6"}, + {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:e1ffa1c924e8c72778b9ccd386a7067cddf626884fd8277f503c48bb5f51c762"}, + {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1af379b33ef33757224da93e9da62e6471cf4a66d10078cf32bae8127d3d0d4a"}, + {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cff084c6933680d1f8b2f3b4ff5bbb88538a4aac00d199ac13f49d0698727ecb"}, + {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2400a94f7dd9cb20cd012951a0cbf8249e3d554c63a9c0cdfd5cbb6c01d2dec"}, + {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d81f1ddae3858b8299d1da72dd7d19dd36aab654c19671aa8a7e7fb02f6638a"}, + {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:255ca2b219f9e5a3a9ef3081512e1358bd4760ce77828e1028b818ff5610b87b"}, + {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a882ac0a351288dd18ecae3326b8a49d10c61a68b01419f3a0b9a306190baf69"}, + {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:724c292bb26365659fc434e9567b3f1adbdb5e8d640c936ed901f49e03e5d32e"}, + {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ca1ed0bb2d850aa8471387882247c68f1e62a4af0ce9c8a1dbe0d2bf69e41fb"}, + {file = "pyzmq-25.1.1-cp312-cp312-win32.whl", hash = "sha256:b3451108ab861040754fa5208bca4a5496c65875710f76789a9ad27c801a0075"}, + {file = "pyzmq-25.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:eadbefd5e92ef8a345f0525b5cfd01cf4e4cc651a2cffb8f23c0dd184975d787"}, + {file = "pyzmq-25.1.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:db0b2af416ba735c6304c47f75d348f498b92952f5e3e8bff449336d2728795d"}, + {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7c133e93b405eb0d36fa430c94185bdd13c36204a8635470cccc200723c13bb"}, + {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:273bc3959bcbff3f48606b28229b4721716598d76b5aaea2b4a9d0ab454ec062"}, + {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cbc8df5c6a88ba5ae385d8930da02201165408dde8d8322072e3e5ddd4f68e22"}, + {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:18d43df3f2302d836f2a56f17e5663e398416e9dd74b205b179065e61f1a6edf"}, + {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = 
"sha256:73461eed88a88c866656e08f89299720a38cb4e9d34ae6bf5df6f71102570f2e"}, + {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:34c850ce7976d19ebe7b9d4b9bb8c9dfc7aac336c0958e2651b88cbd46682123"}, + {file = "pyzmq-25.1.1-cp36-cp36m-win32.whl", hash = "sha256:d2045d6d9439a0078f2a34b57c7b18c4a6aef0bee37f22e4ec9f32456c852c71"}, + {file = "pyzmq-25.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:458dea649f2f02a0b244ae6aef8dc29325a2810aa26b07af8374dc2a9faf57e3"}, + {file = "pyzmq-25.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7cff25c5b315e63b07a36f0c2bab32c58eafbe57d0dce61b614ef4c76058c115"}, + {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1579413ae492b05de5a6174574f8c44c2b9b122a42015c5292afa4be2507f28"}, + {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3d0a409d3b28607cc427aa5c30a6f1e4452cc44e311f843e05edb28ab5e36da0"}, + {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:21eb4e609a154a57c520e3d5bfa0d97e49b6872ea057b7c85257b11e78068222"}, + {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:034239843541ef7a1aee0c7b2cb7f6aafffb005ede965ae9cbd49d5ff4ff73cf"}, + {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f8115e303280ba09f3898194791a153862cbf9eef722ad8f7f741987ee2a97c7"}, + {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1a5d26fe8f32f137e784f768143728438877d69a586ddeaad898558dc971a5ae"}, + {file = "pyzmq-25.1.1-cp37-cp37m-win32.whl", hash = "sha256:f32260e556a983bc5c7ed588d04c942c9a8f9c2e99213fec11a031e316874c7e"}, + {file = "pyzmq-25.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:abf34e43c531bbb510ae7e8f5b2b1f2a8ab93219510e2b287a944432fad135f3"}, + {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:87e34f31ca8f168c56d6fbf99692cc8d3b445abb5bfd08c229ae992d7547a92a"}, + {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c9c6c9b2c2f80747a98f34ef491c4d7b1a8d4853937bb1492774992a120f475d"}, + {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5619f3f5a4db5dbb572b095ea3cb5cc035335159d9da950830c9c4db2fbb6995"}, + {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5a34d2395073ef862b4032343cf0c32a712f3ab49d7ec4f42c9661e0294d106f"}, + {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25f0e6b78220aba09815cd1f3a32b9c7cb3e02cb846d1cfc526b6595f6046618"}, + {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3669cf8ee3520c2f13b2e0351c41fea919852b220988d2049249db10046a7afb"}, + {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2d163a18819277e49911f7461567bda923461c50b19d169a062536fffe7cd9d2"}, + {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:df27ffddff4190667d40de7beba4a950b5ce78fe28a7dcc41d6f8a700a80a3c0"}, + {file = "pyzmq-25.1.1-cp38-cp38-win32.whl", hash = "sha256:a382372898a07479bd34bda781008e4a954ed8750f17891e794521c3e21c2e1c"}, + {file = "pyzmq-25.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:52533489f28d62eb1258a965f2aba28a82aa747202c8fa5a1c7a43b5db0e85c1"}, + {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:03b3f49b57264909aacd0741892f2aecf2f51fb053e7d8ac6767f6c700832f45"}, + {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:330f9e188d0d89080cde66dc7470f57d1926ff2fb5576227f14d5be7ab30b9fa"}, + {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:2ca57a5be0389f2a65e6d3bb2962a971688cbdd30b4c0bd188c99e39c234f414"}, + {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d457aed310f2670f59cc5b57dcfced452aeeed77f9da2b9763616bd57e4dbaae"}, + {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c56d748ea50215abef7030c72b60dd723ed5b5c7e65e7bc2504e77843631c1a6"}, + {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8f03d3f0d01cb5a018debeb412441996a517b11c5c17ab2001aa0597c6d6882c"}, + {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:820c4a08195a681252f46926de10e29b6bbf3e17b30037bd4250d72dd3ddaab8"}, + {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:17ef5f01d25b67ca8f98120d5fa1d21efe9611604e8eb03a5147360f517dd1e2"}, + {file = "pyzmq-25.1.1-cp39-cp39-win32.whl", hash = "sha256:04ccbed567171579ec2cebb9c8a3e30801723c575601f9a990ab25bcac6b51e2"}, + {file = "pyzmq-25.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:e61f091c3ba0c3578411ef505992d356a812fb200643eab27f4f70eed34a29ef"}, + {file = "pyzmq-25.1.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ade6d25bb29c4555d718ac6d1443a7386595528c33d6b133b258f65f963bb0f6"}, + {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0c95ddd4f6e9fca4e9e3afaa4f9df8552f0ba5d1004e89ef0a68e1f1f9807c7"}, + {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48e466162a24daf86f6b5ca72444d2bf39a5e58da5f96370078be67c67adc978"}, + {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abc719161780932c4e11aaebb203be3d6acc6b38d2f26c0f523b5b59d2fc1996"}, + {file = "pyzmq-25.1.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1ccf825981640b8c34ae54231b7ed00271822ea1c6d8ba1090ebd4943759abf5"}, + {file = "pyzmq-25.1.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c2f20ce161ebdb0091a10c9ca0372e023ce24980d0e1f810f519da6f79c60800"}, + {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:deee9ca4727f53464daf089536e68b13e6104e84a37820a88b0a057b97bba2d2"}, + {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aa8d6cdc8b8aa19ceb319aaa2b660cdaccc533ec477eeb1309e2a291eaacc43a"}, + {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:019e59ef5c5256a2c7378f2fb8560fc2a9ff1d315755204295b2eab96b254d0a"}, + {file = "pyzmq-25.1.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:b9af3757495c1ee3b5c4e945c1df7be95562277c6e5bccc20a39aec50f826cd0"}, + {file = "pyzmq-25.1.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:548d6482dc8aadbe7e79d1b5806585c8120bafa1ef841167bc9090522b610fa6"}, + {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:057e824b2aae50accc0f9a0570998adc021b372478a921506fddd6c02e60308e"}, + {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2243700cc5548cff20963f0ca92d3e5e436394375ab8a354bbea2b12911b20b0"}, + {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79986f3b4af059777111409ee517da24a529bdbd46da578b33f25580adcff728"}, + 
{file = "pyzmq-25.1.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:11d58723d44d6ed4dd677c5615b2ffb19d5c426636345567d6af82be4dff8a55"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:49d238cf4b69652257db66d0c623cd3e09b5d2e9576b56bc067a396133a00d4a"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fedbdc753827cf014c01dbbee9c3be17e5a208dcd1bf8641ce2cd29580d1f0d4"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bc16ac425cc927d0a57d242589f87ee093884ea4804c05a13834d07c20db203c"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11c1d2aed9079c6b0c9550a7257a836b4a637feb334904610f06d70eb44c56d2"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e8a701123029cc240cea61dd2d16ad57cab4691804143ce80ecd9286b464d180"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:61706a6b6c24bdece85ff177fec393545a3191eeda35b07aaa1458a027ad1304"}, + {file = "pyzmq-25.1.1.tar.gz", hash = "sha256:259c22485b71abacdfa8bf79720cd7bcf4b9d128b30ea554f01ae71fdbfdaa23"}, +] + +[package.dependencies] +cffi = {version = "*", markers = "implementation_name == \"pypy\""} + +[[package]] +name = "qtconsole" +version = "5.5.1" +description = "Jupyter Qt console" +optional = false +python-versions = ">= 3.8" +files = [ + {file = "qtconsole-5.5.1-py3-none-any.whl", hash = "sha256:8c75fa3e9b4ed884880ff7cea90a1b67451219279ec33deaee1d59e3df1a5d2b"}, + {file = "qtconsole-5.5.1.tar.gz", hash = "sha256:a0e806c6951db9490628e4df80caec9669b65149c7ba40f9bf033c025a5b56bc"}, +] + +[package.dependencies] +ipykernel = ">=4.1" +jupyter-client = ">=4.1" +jupyter-core = "*" +packaging = "*" +pygments = "*" +pyzmq = ">=17.1" +qtpy = ">=2.4.0" +traitlets = "<5.2.1 || >5.2.1,<5.2.2 || >5.2.2" + +[package.extras] +doc = ["Sphinx (>=1.3)"] +test = ["flaky", "pytest", "pytest-qt"] + +[[package]] +name = "qtpy" +version = "2.4.1" +description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." +optional = false +python-versions = ">=3.7" +files = [ + {file = "QtPy-2.4.1-py3-none-any.whl", hash = "sha256:1c1d8c4fa2c884ae742b069151b0abe15b3f70491f3972698c683b8e38de839b"}, + {file = "QtPy-2.4.1.tar.gz", hash = "sha256:a5a15ffd519550a1361bdc56ffc07fda56a6af7292f17c7b395d4083af632987"}, +] + +[package.dependencies] +packaging = "*" + +[package.extras] +test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] + +[[package]] +name = "referencing" +version = "0.31.0" +description = "JSON Referencing + Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "referencing-0.31.0-py3-none-any.whl", hash = "sha256:381b11e53dd93babb55696c71cf42aef2d36b8a150c49bf0bc301e36d536c882"}, + {file = "referencing-0.31.0.tar.gz", hash = "sha256:cc28f2c88fbe7b961a7817a0abc034c09a1e36358f82fedb4ffdf29a25398863"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +rpds-py = ">=0.7.0" + +[[package]] +name = "requests" +version = "2.31.0" +description = "Python HTTP for Humans." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, +] + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] + +[[package]] +name = "rfc3339-validator" +version = "0.1.4" +description = "A pure python RFC3339 validator" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "rfc3339_validator-0.1.4-py2.py3-none-any.whl", hash = "sha256:24f6ec1eda14ef823da9e36ec7113124b39c04d50a4d3d3a3c2859577e7791fa"}, + {file = "rfc3339_validator-0.1.4.tar.gz", hash = "sha256:138a2abdf93304ad60530167e51d2dfb9549521a836871b88d7f4695d0022f6b"}, +] + +[package.dependencies] +six = "*" + +[[package]] +name = "rfc3986-validator" +version = "0.1.1" +description = "Pure python rfc3986 validator" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "rfc3986_validator-0.1.1-py2.py3-none-any.whl", hash = "sha256:2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9"}, + {file = "rfc3986_validator-0.1.1.tar.gz", hash = "sha256:3d44bde7921b3b9ec3ae4e3adca370438eccebc676456449b145d533b240d055"}, +] + +[[package]] +name = "rpds-py" +version = "0.13.0" +description = "Python bindings to Rust's persistent data structures (rpds)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "rpds_py-0.13.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:1758197cc8d7ff383c07405f188253535b4aa7fa745cbc54d221ae84b18e0702"}, + {file = "rpds_py-0.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:715df74cbcef4387d623c917f295352127f4b3e0388038d68fa577b4e4c6e540"}, + {file = "rpds_py-0.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8a9cec0f49df9bac252d92f138c0d7708d98828e21fd57db78087d8f50b5656"}, + {file = "rpds_py-0.13.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5c2545bba02f68abdf398ef4990dc77592cc1e5d29438b35b3a3ca34d171fb4b"}, + {file = "rpds_py-0.13.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:95375c44ffb9ea2bc25d67fb66e726ea266ff1572df50b9556fe28a5f3519cd7"}, + {file = "rpds_py-0.13.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:54e513df45a8a9419e7952ffd26ac9a5b7b1df97fe72530421794b0de29f9d72"}, + {file = "rpds_py-0.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a25f514a53927b6b4bd04a9a6a13b55209df54f548660eeed673336c0c946d14"}, + {file = "rpds_py-0.13.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c1a920fa679ec2758411d66bf68840b0a21317b9954ab0e973742d723bb67709"}, + {file = "rpds_py-0.13.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f9339d1404b87e6d8cb35e485945753be57a99ab9bb389f42629215b2f6bda0f"}, + {file = "rpds_py-0.13.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:c99f9dda2c959f7bb69a7125e192c74fcafb7a534a95ccf49313ae3a04807804"}, + {file = "rpds_py-0.13.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bad6758df5f1042b35683bd1811d5432ac1b17700a5a2a51fdc293f7df5f7827"}, + {file = "rpds_py-0.13.0-cp310-none-win32.whl", hash = 
"sha256:2a29ec68fa9655ce9501bc6ae074b166e8b45c2dfcd2d71d90d1a61758ed8c73"}, + {file = "rpds_py-0.13.0-cp310-none-win_amd64.whl", hash = "sha256:244be953f13f148b0071d67a610f89cd72eb5013a147e517d6ca3f3f3b7e0380"}, + {file = "rpds_py-0.13.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:240279ca0b2afd6d4710afce1c94bf9e75fc161290bf62c0feba64d64780d80b"}, + {file = "rpds_py-0.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:25c9727da2dabc93664a18eda7a70feedf478f0c4c8294e4cdba7f60a479a246"}, + {file = "rpds_py-0.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:981e46e1e5064f95460381bff4353783b4b5ce351c930e5b507ebe0278c61dac"}, + {file = "rpds_py-0.13.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6052bb47ea583646b8ff562acacb9a2ec5ec847267049cbae3919671929e94c6"}, + {file = "rpds_py-0.13.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87f591ff8cc834fa01ca5899ab5edcd7ee590492a9cdcf43424ac142e731ce3e"}, + {file = "rpds_py-0.13.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:62772259b3381e2aabf274c74fd1e1ac03b0524de0a6593900684becfa8cfe4b"}, + {file = "rpds_py-0.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4de9d20fe68c16b4d97f551a09920745add0c86430262230528b83c2ed2fe90"}, + {file = "rpds_py-0.13.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b70a54fb628c1d6400e351674a31ba63d2912b8c5b707f99b408674a5d8b69ab"}, + {file = "rpds_py-0.13.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2063ab9cd1be7ef6b5ed0f408e2bdf32c060b6f40c097a468f32864731302636"}, + {file = "rpds_py-0.13.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:84f7f3f18d29a1c645729634003d21d84028bd9c2fd78eba9d028998f46fa5aa"}, + {file = "rpds_py-0.13.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f7c7ddc8d1a64623068da5a15e28001fbd0f0aff754aae7a75a4be5042191638"}, + {file = "rpds_py-0.13.0-cp311-none-win32.whl", hash = "sha256:8a33d2b6340261191bb59adb5a453fa6c7d99de85552bd4e8196411f0509c9bf"}, + {file = "rpds_py-0.13.0-cp311-none-win_amd64.whl", hash = "sha256:8b9c1dd90461940315981499df62a627571c4f0992e8bafc5396d33916224cac"}, + {file = "rpds_py-0.13.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:15a2d542de5cbfc6abddc4846d9412b59f8ee9c8dfa0b9c92a29321297c91745"}, + {file = "rpds_py-0.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8dd69e01b29ff45a0062cad5c480d8aa9301c3ef09da471f86337a78eb2d3405"}, + {file = "rpds_py-0.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efdd02971a02f98492a72b25484f1f6125fb9f2166e48cc4c9bfa563349c851b"}, + {file = "rpds_py-0.13.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:91ca9aaee7ccdfa66d800b5c4ec634fefca947721bab52d6ad2f6350969a3771"}, + {file = "rpds_py-0.13.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:afcec1f5b09d0db70aeb2d90528a9164acb61841a3124e28f6ac0137f4c36cb4"}, + {file = "rpds_py-0.13.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c6824673f66c47f7ee759c21e973bfce3ceaf2c25cb940cb45b41105dc914e8"}, + {file = "rpds_py-0.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50b6d80925dfeb573fc5e38582fb9517c6912dc462cc858a11c8177b0837127a"}, + {file = "rpds_py-0.13.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3a1a38512925829784b5dc38591c757b80cfce115c72c594dc59567dab62b9c4"}, + {file = 
"rpds_py-0.13.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:977c6123c359dcc70ce3161b781ab70b0d342de2666944b776617e01a0a7822a"}, + {file = "rpds_py-0.13.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c472409037e05ed87b99430f97a6b82130328bb977502813547e8ee6a3392502"}, + {file = "rpds_py-0.13.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:28bb22019f4a783ea06a6b81437d5996551869e8a722ee8720b744f7684d97f4"}, + {file = "rpds_py-0.13.0-cp312-none-win32.whl", hash = "sha256:46be9c0685cce2ea02151aa8308f2c1b78581be41a5dd239448a941a210ef5dd"}, + {file = "rpds_py-0.13.0-cp312-none-win_amd64.whl", hash = "sha256:3c5b9ad4d3e05dfcf8629f0d534f92610e9805dbce2fcb9b3c801ddb886431d5"}, + {file = "rpds_py-0.13.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:66eb5aa36e857f768c598d2082fafb733eaf53e06e1169c6b4de65636e04ffd0"}, + {file = "rpds_py-0.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c9f4c2b7d989426e9fe9b720211172cf10eb5f7aa16c63de2e5dc61457abcf35"}, + {file = "rpds_py-0.13.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1e37dfffe8959a492b7b331995f291847a41a035b4aad82d6060f38e8378a2b"}, + {file = "rpds_py-0.13.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8220321f2dccd9d66f72639185247cb7bbdd90753bf0b6bfca0fa31dba8af23c"}, + {file = "rpds_py-0.13.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e8f1d466a9747213d3cf7e1afec849cc51edb70d5b4ae9a82eca0f172bfbb6d0"}, + {file = "rpds_py-0.13.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c4c4b4ff3de834ec5c1c690e5a18233ca78547d003eb83664668ccf09ef1398"}, + {file = "rpds_py-0.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:525d19ef0a999229ef0f0a7687ab2c9a00d1b6a47a005006f4d8c4b8975fdcec"}, + {file = "rpds_py-0.13.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0982b59d014efb84a57128e7e69399fb29ad8f2da5b0a5bcbfd12e211c00492e"}, + {file = "rpds_py-0.13.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f714dd5b705f1c394d1b361d96486c4981055c434a7eafb1a3147ac75e34a3de"}, + {file = "rpds_py-0.13.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:766b573a964389ef0d91a26bb31e1b59dbc5d06eff7707f3dfcec23d93080ba3"}, + {file = "rpds_py-0.13.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2ed65ad3fc5065d13e31e90794e0b52e405b63ae4fab1080caeaadc10a3439c5"}, + {file = "rpds_py-0.13.0-cp38-none-win32.whl", hash = "sha256:9645f7fe10a68b2396d238250b4b264c2632d2eb6ce2cb90aa0fe08adee194be"}, + {file = "rpds_py-0.13.0-cp38-none-win_amd64.whl", hash = "sha256:42d0ad129c102856a364ccc7d356faec017af86b3543a8539795f22b6cabad11"}, + {file = "rpds_py-0.13.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:95c11647fac2a3515ea2614a79e14b7c75025724ad54c91c7db4a6ea5c25ef19"}, + {file = "rpds_py-0.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9435bf4832555c4f769c6be9401664357be33d5f5d8dc58f5c20fb8d21e2c45d"}, + {file = "rpds_py-0.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54b1d671a74395344239ee3adbcd8c496525f6a2b2e54c40fec69620a31a8dcb"}, + {file = "rpds_py-0.13.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:13c8061115f1468de6ffdfb1d31b446e1bd814f1ff6e556862169aacb9fbbc5d"}, + {file = "rpds_py-0.13.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a78861123b002725633871a2096c3a4313224aab3d11b953dced87cfba702418"}, + {file = 
"rpds_py-0.13.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97c1be5a018cdad54fa7e5f7d36b9ab45ef941a1d185987f18bdab0a42344012"}, + {file = "rpds_py-0.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e33b17915c8e4fb2ea8b91bb4c46cba92242c63dd38b87e869ead5ba217e2970"}, + {file = "rpds_py-0.13.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:153b6d8cf7ae4b9ffd09de6abeda661e351e3e06eaafd18a8c104ea00099b131"}, + {file = "rpds_py-0.13.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:da2852201e8e00c86be82c43d6893e6c380ef648ae53f337ffd1eaa35e3dfb8a"}, + {file = "rpds_py-0.13.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:a2383f400691fd7bd63347d4d75eb2fd525de9d901799a33a4e896c9885609f8"}, + {file = "rpds_py-0.13.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d5bf560634ea6e9a59ceb2181a6cd6195a03f48cef9a400eb15e197e18f14548"}, + {file = "rpds_py-0.13.0-cp39-none-win32.whl", hash = "sha256:fdaef49055cc0c701fb17b9b34a38ef375e5cdb230b3722d4a12baf9b7cbc6d3"}, + {file = "rpds_py-0.13.0-cp39-none-win_amd64.whl", hash = "sha256:26660c74a20fe249fad75ca00bbfcf60e57c3fdbde92971c88a20e07fea1de64"}, + {file = "rpds_py-0.13.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:28324f2f0247d407daabf7ff357ad9f36126075c92a0cf5319396d96ff4e1248"}, + {file = "rpds_py-0.13.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b431c2c0ff1ea56048a2b066d99d0c2d151ae7625b20be159b7e699f3e80390b"}, + {file = "rpds_py-0.13.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7472bd60a8293217444bdc6a46e516feb8d168da44d5f3fccea0336e88e3b79a"}, + {file = "rpds_py-0.13.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:169063f346b8fd84f47d986c9c48e6094eb38b839c1287e7cb886b8a2b32195d"}, + {file = "rpds_py-0.13.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eef7ee7c70f8b8698be468d54f9f5e01804f3a1dd5657e8a96363dbd52b9b5ec"}, + {file = "rpds_py-0.13.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:762013dd59df12380c5444f61ccbf9ae1297027cabbd7aa25891f724ebf8c8f7"}, + {file = "rpds_py-0.13.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:152570689a27ae0be1d5f50b21dad38d450b9227d0974f23bd400400ea087e88"}, + {file = "rpds_py-0.13.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d70a93a40e55da117c511ddc514642bc7d59a95a99137168a5f3f2f876b47962"}, + {file = "rpds_py-0.13.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:e6c6fed07d13b9e0fb689356c40c81f1aa92e3c9d91d8fd5816a0348ccd999f7"}, + {file = "rpds_py-0.13.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:cdded3cf9e36840b09ccef714d5fa74a03f4eb6cf81e694226ed9cb5e6f90de0"}, + {file = "rpds_py-0.13.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:e1f40faf406c52c7ae7d208b9140377c06397248978ccb03fbfbb30a0571e359"}, + {file = "rpds_py-0.13.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:c10326e30c97a95b7e1d75e5200ef0b9827aa0f861e331e43b15dfdfd63e669b"}, + {file = "rpds_py-0.13.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:afde37e3763c602d0385bce5c12f262e7b1dd2a0f323e239fa9d7b2d4d5d8509"}, + {file = "rpds_py-0.13.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4084ab6840bd4d79eff3b5f497add847a7db31ce5a0c2d440c90b2d2b7011857"}, + {file = 
"rpds_py-0.13.0-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1c9c9cb48ab77ebfa47db25b753f594d4f44959cfe43b713439ca6e3c9329671"}, + {file = "rpds_py-0.13.0-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:533d728ea5ad5253af3395102723ca8a77b62de47b2295155650c9a88fcdeec8"}, + {file = "rpds_py-0.13.0-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f22cab655b41033d430f20266bf563b35038a7f01c9a099b0ccfd30a7fb9247"}, + {file = "rpds_py-0.13.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9a0507342c37132813449393e6e6f351bbff376031cfff1ee6e616402ac7908"}, + {file = "rpds_py-0.13.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4eb1faf8e2ee9a2de3cb3ae4c8c355914cdc85f2cd7f27edf76444c9550ce1e7"}, + {file = "rpds_py-0.13.0-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:a61a152d61e3ae26e0bbba7b2f568f6f25ca0abdeb6553eca7e7c45b59d9b1a9"}, + {file = "rpds_py-0.13.0-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:e499bf2200eb74774a6f85a7465e3bc5273fa8ef0055590d97a88c1e7ea02eea"}, + {file = "rpds_py-0.13.0-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:1e5becd0de924616ca9a12abeb6458568d1dc8fe5c670d5cdb738402a8a8429d"}, + {file = "rpds_py-0.13.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:70cfe098d915f566eeebcb683f49f9404d2f948432891b6e075354336eda9dfb"}, + {file = "rpds_py-0.13.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:2e73511e88368f93c24efe7c9a20b319eaa828bc7431f8a17713efb9e31a39fa"}, + {file = "rpds_py-0.13.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c07cb9bcccd08f9bc2fd05bf586479df4272ea5a6a70fbcb59b018ed48a5a84d"}, + {file = "rpds_py-0.13.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8c4e84016ba225e09df20fed8befe8c68d14fbeff6078f4a0ff907ae2095e17e"}, + {file = "rpds_py-0.13.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6ad465e5a70580ca9c1944f43a9a71bca3a7b74554347fc96ca0479eca8981f9"}, + {file = "rpds_py-0.13.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:189aebd44a07fa7b7966cf78b85bde8335b0b6c3b1c4ef5589f8c03176830107"}, + {file = "rpds_py-0.13.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f50ca0460f1f7a89ab9b8355d83ac993d5998ad4218e76654ecf8afe648d8aa"}, + {file = "rpds_py-0.13.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f6c225011467021879c0482316e42d8a28852fc29f0c15d2a435ff457cadccd4"}, + {file = "rpds_py-0.13.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:1e63b32b856c0f08a56b76967d61b6ad811d8d330a8aebb9d21afadd82a296f6"}, + {file = "rpds_py-0.13.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:7e5fbe9800f09c56967fda88c4d9272955e781699a66102bd098f22511a3f260"}, + {file = "rpds_py-0.13.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:fea99967d4a978ce95dd52310bcb4a943b77c61725393bca631b0908047d6e2f"}, + {file = "rpds_py-0.13.0.tar.gz", hash = "sha256:35cc91cbb0b775705e0feb3362490b8418c408e9e3c3b9cb3b02f6e495f03ee7"}, +] + +[[package]] +name = "ruff" +version = "0.1.6" +description = "An extremely fast Python linter and code formatter, written in Rust." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.1.6-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:88b8cdf6abf98130991cbc9f6438f35f6e8d41a02622cc5ee130a02a0ed28703"}, + {file = "ruff-0.1.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c549ed437680b6105a1299d2cd30e4964211606eeb48a0ff7a93ef70b902248"}, + {file = "ruff-0.1.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cf5f701062e294f2167e66d11b092bba7af6a057668ed618a9253e1e90cfd76"}, + {file = "ruff-0.1.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05991ee20d4ac4bb78385360c684e4b417edd971030ab12a4fbd075ff535050e"}, + {file = "ruff-0.1.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87455a0c1f739b3c069e2f4c43b66479a54dea0276dd5d4d67b091265f6fd1dc"}, + {file = "ruff-0.1.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:683aa5bdda5a48cb8266fcde8eea2a6af4e5700a392c56ea5fb5f0d4bfdc0240"}, + {file = "ruff-0.1.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:137852105586dcbf80c1717facb6781555c4e99f520c9c827bd414fac67ddfb6"}, + {file = "ruff-0.1.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd98138a98d48a1c36c394fd6b84cd943ac92a08278aa8ac8c0fdefcf7138f35"}, + {file = "ruff-0.1.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a0cd909d25f227ac5c36d4e7e681577275fb74ba3b11d288aff7ec47e3ae745"}, + {file = "ruff-0.1.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e8fd1c62a47aa88a02707b5dd20c5ff20d035d634aa74826b42a1da77861b5ff"}, + {file = "ruff-0.1.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd89b45d374935829134a082617954120d7a1470a9f0ec0e7f3ead983edc48cc"}, + {file = "ruff-0.1.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:491262006e92f825b145cd1e52948073c56560243b55fb3b4ecb142f6f0e9543"}, + {file = "ruff-0.1.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:ea284789861b8b5ca9d5443591a92a397ac183d4351882ab52f6296b4fdd5462"}, + {file = "ruff-0.1.6-py3-none-win32.whl", hash = "sha256:1610e14750826dfc207ccbcdd7331b6bd285607d4181df9c1c6ae26646d6848a"}, + {file = "ruff-0.1.6-py3-none-win_amd64.whl", hash = "sha256:4558b3e178145491e9bc3b2ee3c4b42f19d19384eaa5c59d10acf6e8f8b57e33"}, + {file = "ruff-0.1.6-py3-none-win_arm64.whl", hash = "sha256:03910e81df0d8db0e30050725a5802441c2022ea3ae4fe0609b76081731accbc"}, + {file = "ruff-0.1.6.tar.gz", hash = "sha256:1b09f29b16c6ead5ea6b097ef2764b42372aebe363722f1605ecbcd2b9207184"}, +] + +[[package]] +name = "send2trash" +version = "1.8.2" +description = "Send file to trash natively under Mac OS X, Windows and Linux" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" +files = [ + {file = "Send2Trash-1.8.2-py3-none-any.whl", hash = "sha256:a384719d99c07ce1eefd6905d2decb6f8b7ed054025bb0e618919f945de4f679"}, + {file = "Send2Trash-1.8.2.tar.gz", hash = "sha256:c132d59fa44b9ca2b1699af5c86f57ce9f4c5eb56629d5d55fbb7a35f84e2312"}, +] + +[package.extras] +nativelib = ["pyobjc-framework-Cocoa", "pywin32"] +objc = ["pyobjc-framework-Cocoa"] +win32 = ["pywin32"] + +[[package]] +name = "setuptools" +version = "67.8.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "setuptools-67.8.0-py3-none-any.whl", hash = "sha256:5df61bf30bb10c6f756eb19e7c9f3b473051f48db77fddbe06ff2ca307df9a6f"}, + {file = 
"setuptools-67.8.0.tar.gz", hash = "sha256:62642358adc77ffa87233bc4d2354c4b2682d214048f500964dbe760ccedf102"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] + +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + +[[package]] +name = "sniffio" +version = "1.3.0" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, + {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, +] + +[[package]] +name = "soupsieve" +version = "2.5" +description = "A modern CSS selector implementation for Beautiful Soup." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, + {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, +] + +[[package]] +name = "stack-data" +version = "0.6.3" +description = "Extract data from python stack frames and tracebacks for informative displays" +optional = false +python-versions = "*" +files = [ + {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, + {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, +] + +[package.dependencies] +asttokens = ">=2.1.0" +executing = ">=1.2.0" +pure-eval = "*" + +[package.extras] +tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] + +[[package]] +name = "syrupy" +version = "4.6.0" +description = "Pytest Snapshot Test Utility" +optional = false +python-versions = ">=3.8.1,<4" +files = [ + {file = "syrupy-4.6.0-py3-none-any.whl", hash = "sha256:747aae1bcf3cb3249e33b1e6d81097874d23615982d5686ebe637875b0775a1b"}, + {file = "syrupy-4.6.0.tar.gz", hash = "sha256:231b1f5d00f1f85048ba81676c79448076189c4aef4d33f21ae32f3b4c565a54"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<8.0.0" + +[[package]] +name = "tenacity" +version = "8.2.3" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"}, + {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"}, +] + +[package.extras] +doc = ["reno", "sphinx", "tornado (>=4.5)"] + +[[package]] +name = "terminado" +version = "0.18.0" +description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "terminado-0.18.0-py3-none-any.whl", hash = "sha256:87b0d96642d0fe5f5abd7783857b9cab167f221a39ff98e3b9619a788a3c0f2e"}, + {file = "terminado-0.18.0.tar.gz", hash = "sha256:1ea08a89b835dd1b8c0c900d92848147cef2537243361b2e3f4dc15df9b6fded"}, +] + +[package.dependencies] +ptyprocess = {version = "*", markers = "os_name != \"nt\""} +pywinpty = {version = ">=1.1.0", markers = "os_name == \"nt\""} +tornado = ">=6.1.0" + +[package.extras] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] +test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] +typing = ["mypy (>=1.6,<2.0)", "traitlets (>=5.11.1)"] + +[[package]] +name = "tinycss2" +version = "1.2.1" +description = "A tiny CSS parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tinycss2-1.2.1-py3-none-any.whl", hash = "sha256:2b80a96d41e7c3914b8cda8bc7f705a4d9c49275616e886103dd839dfc847847"}, + {file = "tinycss2-1.2.1.tar.gz", hash = "sha256:8cff3a8f066c2ec677c06dbc7b45619804a6938478d9d73c284b29d14ecb0627"}, +] + +[package.dependencies] +webencodings = ">=0.4" + +[package.extras] +doc = ["sphinx", "sphinx_rtd_theme"] +test = ["flake8", "isort", "pytest"] + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + +[[package]] +name = "tornado" +version = "6.3.3" +description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." +optional = false +python-versions = ">= 3.8" +files = [ + {file = "tornado-6.3.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:502fba735c84450974fec147340016ad928d29f1e91f49be168c0a4c18181e1d"}, + {file = "tornado-6.3.3-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:805d507b1f588320c26f7f097108eb4023bbaa984d63176d1652e184ba24270a"}, + {file = "tornado-6.3.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bd19ca6c16882e4d37368e0152f99c099bad93e0950ce55e71daed74045908f"}, + {file = "tornado-6.3.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ac51f42808cca9b3613f51ffe2a965c8525cb1b00b7b2d56828b8045354f76a"}, + {file = "tornado-6.3.3-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71a8db65160a3c55d61839b7302a9a400074c9c753040455494e2af74e2501f2"}, + {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:ceb917a50cd35882b57600709dd5421a418c29ddc852da8bcdab1f0db33406b0"}, + {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:7d01abc57ea0dbb51ddfed477dfe22719d376119844e33c661d873bf9c0e4a16"}, + {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:9dc4444c0defcd3929d5c1eb5706cbe1b116e762ff3e0deca8b715d14bf6ec17"}, + {file = "tornado-6.3.3-cp38-abi3-win32.whl", hash = "sha256:65ceca9500383fbdf33a98c0087cb975b2ef3bfb874cb35b8de8740cf7f41bd3"}, + {file = "tornado-6.3.3-cp38-abi3-win_amd64.whl", hash = "sha256:22d3c2fa10b5793da13c807e6fc38ff49a4f6e1e3868b0a6f4164768bb8e20f5"}, + {file = "tornado-6.3.3.tar.gz", hash = "sha256:e7d8db41c0181c80d76c982aacc442c0783a2c54d6400fe028954201a2e032fe"}, +] + +[[package]] +name = "traitlets" 
+version = "5.13.0" +description = "Traitlets Python configuration system" +optional = false +python-versions = ">=3.8" +files = [ + {file = "traitlets-5.13.0-py3-none-any.whl", hash = "sha256:baf991e61542da48fe8aef8b779a9ea0aa38d8a54166ee250d5af5ecf4486619"}, + {file = "traitlets-5.13.0.tar.gz", hash = "sha256:9b232b9430c8f57288c1024b34a8f0251ddcc47268927367a0dd3eeaca40deb5"}, +] + +[package.extras] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] +test = ["argcomplete (>=3.0.3)", "mypy (>=1.6.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] + +[[package]] +name = "types-python-dateutil" +version = "2.8.19.14" +description = "Typing stubs for python-dateutil" +optional = false +python-versions = "*" +files = [ + {file = "types-python-dateutil-2.8.19.14.tar.gz", hash = "sha256:1f4f10ac98bb8b16ade9dbee3518d9ace017821d94b057a425b069f834737f4b"}, + {file = "types_python_dateutil-2.8.19.14-py3-none-any.whl", hash = "sha256:f977b8de27787639986b4e28963263fd0e5158942b3ecef91b9335c130cb1ce9"}, +] + +[[package]] +name = "types-pyyaml" +version = "6.0.12.12" +description = "Typing stubs for PyYAML" +optional = false +python-versions = "*" +files = [ + {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, + {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, +] + +[[package]] +name = "types-requests" +version = "2.31.0.10" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.7" +files = [ + {file = "types-requests-2.31.0.10.tar.gz", hash = "sha256:dc5852a76f1eaf60eafa81a2e50aefa3d1f015c34cf0cba130930866b1b22a92"}, + {file = "types_requests-2.31.0.10-py3-none-any.whl", hash = "sha256:b32b9a86beffa876c0c3ac99a4cd3b8b51e973fb8e3bd4e0a6bb32c7efad80fc"}, +] + +[package.dependencies] +urllib3 = ">=2" + +[[package]] +name = "typing-extensions" +version = "4.8.0" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, + {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, +] + +[[package]] +name = "uri-template" +version = "1.3.0" +description = "RFC 6570 URI Template Processor" +optional = false +python-versions = ">=3.7" +files = [ + {file = "uri-template-1.3.0.tar.gz", hash = "sha256:0e00f8eb65e18c7de20d595a14336e9f337ead580c70934141624b6d1ffdacc7"}, + {file = "uri_template-1.3.0-py3-none-any.whl", hash = "sha256:a44a133ea12d44a0c0f06d7d42a52d71282e77e2f937d8abd5655b8d56fc1363"}, +] + +[package.extras] +dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"] + +[[package]] +name = "urllib3" +version = "2.1.0" +description = "HTTP library with thread-safe connection pooling, file post, and more." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "urllib3-2.1.0-py3-none-any.whl", hash = "sha256:55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3"}, + {file = "urllib3-2.1.0.tar.gz", hash = "sha256:df7aa8afb0148fa78488e7899b2c59b5f4ffcfa82e6c54ccb9dd37c1d7b52d54"}, +] + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[[package]] +name = "watchdog" +version = "3.0.0" +description = "Filesystem events monitoring" +optional = false +python-versions = ">=3.7" +files = [ + {file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:336adfc6f5cc4e037d52db31194f7581ff744b67382eb6021c868322e32eef41"}, + {file = "watchdog-3.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a70a8dcde91be523c35b2bf96196edc5730edb347e374c7de7cd20c43ed95397"}, + {file = "watchdog-3.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:adfdeab2da79ea2f76f87eb42a3ab1966a5313e5a69a0213a3cc06ef692b0e96"}, + {file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2b57a1e730af3156d13b7fdddfc23dea6487fceca29fc75c5a868beed29177ae"}, + {file = "watchdog-3.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7ade88d0d778b1b222adebcc0927428f883db07017618a5e684fd03b83342bd9"}, + {file = "watchdog-3.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7e447d172af52ad204d19982739aa2346245cc5ba6f579d16dac4bfec226d2e7"}, + {file = "watchdog-3.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9fac43a7466eb73e64a9940ac9ed6369baa39b3bf221ae23493a9ec4d0022674"}, + {file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8ae9cda41fa114e28faf86cb137d751a17ffd0316d1c34ccf2235e8a84365c7f"}, + {file = "watchdog-3.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f70b4aa53bd743729c7475d7ec41093a580528b100e9a8c5b5efe8899592fc"}, + {file = "watchdog-3.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4f94069eb16657d2c6faada4624c39464f65c05606af50bb7902e036e3219be3"}, + {file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7c5f84b5194c24dd573fa6472685b2a27cc5a17fe5f7b6fd40345378ca6812e3"}, + {file = "watchdog-3.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3aa7f6a12e831ddfe78cdd4f8996af9cf334fd6346531b16cec61c3b3c0d8da0"}, + {file = "watchdog-3.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:233b5817932685d39a7896b1090353fc8efc1ef99c9c054e46c8002561252fb8"}, + {file = "watchdog-3.0.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:13bbbb462ee42ec3c5723e1205be8ced776f05b100e4737518c67c8325cf6100"}, + {file = "watchdog-3.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8f3ceecd20d71067c7fd4c9e832d4e22584318983cabc013dbf3f70ea95de346"}, + {file = "watchdog-3.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c9d8c8ec7efb887333cf71e328e39cffbf771d8f8f95d308ea4125bf5f90ba64"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0e06ab8858a76e1219e68c7573dfeba9dd1c0219476c5a44d5333b01d7e1743a"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:d00e6be486affb5781468457b21a6cbe848c33ef43f9ea4a73b4882e5f188a44"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:c07253088265c363d1ddf4b3cdb808d59a0468ecd017770ed716991620b8f77a"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:5113334cf8cf0ac8cd45e1f8309a603291b614191c9add34d33075727a967709"}, + {file = 
"watchdog-3.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:51f90f73b4697bac9c9a78394c3acbbd331ccd3655c11be1a15ae6fe289a8c83"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:ba07e92756c97e3aca0912b5cbc4e5ad802f4557212788e72a72a47ff376950d"}, + {file = "watchdog-3.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d429c2430c93b7903914e4db9a966c7f2b068dd2ebdd2fa9b9ce094c7d459f33"}, + {file = "watchdog-3.0.0-py3-none-win32.whl", hash = "sha256:3ed7c71a9dccfe838c2f0b6314ed0d9b22e77d268c67e015450a29036a81f60f"}, + {file = "watchdog-3.0.0-py3-none-win_amd64.whl", hash = "sha256:4c9956d27be0bb08fc5f30d9d0179a855436e655f046d288e2bcc11adfae893c"}, + {file = "watchdog-3.0.0-py3-none-win_ia64.whl", hash = "sha256:5d9f3a10e02d7371cd929b5d8f11e87d4bad890212ed3901f9b4d68767bee759"}, + {file = "watchdog-3.0.0.tar.gz", hash = "sha256:4d98a320595da7a7c5a18fc48cb633c2e73cda78f93cac2ef42d42bf609a33f9"}, +] + +[package.extras] +watchmedo = ["PyYAML (>=3.10)"] + +[[package]] +name = "wcwidth" +version = "0.2.10" +description = "Measures the displayed width of unicode strings in a terminal" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.10-py2.py3-none-any.whl", hash = "sha256:aec5179002dd0f0d40c456026e74a729661c9d468e1ed64405e3a6c2176ca36f"}, + {file = "wcwidth-0.2.10.tar.gz", hash = "sha256:390c7454101092a6a5e43baad8f83de615463af459201709556b6e4b1c861f97"}, +] + +[[package]] +name = "webcolors" +version = "1.13" +description = "A library for working with the color formats defined by HTML and CSS." +optional = false +python-versions = ">=3.7" +files = [ + {file = "webcolors-1.13-py3-none-any.whl", hash = "sha256:29bc7e8752c0a1bd4a1f03c14d6e6a72e93d82193738fa860cbff59d0fcc11bf"}, + {file = "webcolors-1.13.tar.gz", hash = "sha256:c225b674c83fa923be93d235330ce0300373d02885cef23238813b0d5668304a"}, +] + +[package.extras] +docs = ["furo", "sphinx", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-notfound-page", "sphinxext-opengraph"] +tests = ["pytest", "pytest-cov"] + +[[package]] +name = "webencodings" +version = "0.5.1" +description = "Character encoding aliases for legacy web content" +optional = false +python-versions = "*" +files = [ + {file = "webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78"}, + {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"}, +] + +[[package]] +name = "websocket-client" +version = "1.6.4" +description = "WebSocket client for Python with low level API options" +optional = false +python-versions = ">=3.8" +files = [ + {file = "websocket-client-1.6.4.tar.gz", hash = "sha256:b3324019b3c28572086c4a319f91d1dcd44e6e11cd340232978c684a7650d0df"}, + {file = "websocket_client-1.6.4-py3-none-any.whl", hash = "sha256:084072e0a7f5f347ef2ac3d8698a5e0b4ffbfcab607628cadabc650fc9a83a24"}, +] + +[package.extras] +docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"] +optional = ["python-socks", "wsaccel"] +test = ["websockets"] + +[[package]] +name = "widgetsnbextension" +version = "4.0.9" +description = "Jupyter interactive widgets for Jupyter Notebook" +optional = false +python-versions = ">=3.7" +files = [ + {file = "widgetsnbextension-4.0.9-py3-none-any.whl", hash = "sha256:91452ca8445beb805792f206e560c1769284267a30ceb1cec9f5bcc887d15175"}, + {file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"}, +] + +[[package]] +name 
= "zipp" +version = "3.17.0" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"}, + {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] + +[metadata] +lock-version = "2.0" +python-versions = ">=3.8.1,<4.0" +content-hash = "b08d47f726dd194af0f801d300402b174c8db96a4184cc1136cb8e5a0e287190" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml new file mode 100644 index 00000000000..b7b8a3184eb --- /dev/null +++ b/libs/core/pyproject.toml @@ -0,0 +1,85 @@ +[tool.poetry] +name = "langchain-core" +version = "0.0.1" +description = "Building applications with LLMs through composability" +authors = [] +license = "MIT" +readme = "README.md" +repository = "https://github.com/langchain-ai/langchain" + + +[tool.poetry.dependencies] +python = ">=3.8.1,<4.0" +pydantic = ">=1,<3" +langsmith = "~0.0.63" +tenacity = "^8.1.0" +jsonpatch = "^1.33" + +[tool.poetry.group.lint.dependencies] +ruff = "^0.1.5" + +[tool.poetry.group.typing.dependencies] +mypy = "^0.991" +types-pyyaml = "^6.0.12.2" +types-requests = "^2.28.11.5" + +[tool.poetry.group.dev.dependencies] +jupyter = "^1.0.0" +setuptools = "^67.6.1" + +[tool.poetry.group.test.dependencies] +# The only dependencies that should be added are +# dependencies used for running tests (e.g., pytest, freezegun, response). +# Any dependencies that do not meet that criteria will be removed. +pytest = "^7.3.0" +freezegun = "^1.2.2" +pytest-mock = "^3.10.0" +syrupy = "^4.0.2" +pytest-watcher = "^0.3.4" +pytest-asyncio = "^0.21.1" + + +[tool.poetry.group.test_integration] +optional = true +dependencies = {} + +[tool.ruff] +select = [ + "E", # pycodestyle + "F", # pyflakes + "I", # isort +] + +[tool.mypy] +ignore_missing_imports = "True" +disallow_untyped_defs = "True" +exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"] + +[tool.coverage.run] +omit = [ + "tests/*", +] + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +# --strict-markers will raise errors on unknown marks. +# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks +# +# https://docs.pytest.org/en/7.1.x/reference/reference.html +# --strict-config any warnings encountered while parsing the `pytest` +# section of the configuration file raise errors. +# +# https://github.com/tophat/syrupy +# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. +addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5" +# Registering custom markers. 
+# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers +markers = [ + "requires: mark tests as requiring a specific library", + "asyncio: mark tests as requiring asyncio", + "compile: mark placeholder test used to compile integration tests without running them", +] +asyncio_mode = "auto" diff --git a/libs/core/tests/__init__.py b/libs/core/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/core/tests/unit_tests/__init__.py b/libs/core/tests/unit_tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/core/tests/unit_tests/_api/__init__.py b/libs/core/tests/unit_tests/_api/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/langchain/tests/unit_tests/_api/test_deprecation.py b/libs/core/tests/unit_tests/_api/test_deprecation.py similarity index 98% rename from libs/langchain/tests/unit_tests/_api/test_deprecation.py rename to libs/core/tests/unit_tests/_api/test_deprecation.py index cecb034d4d6..238ba231fa3 100644 --- a/libs/langchain/tests/unit_tests/_api/test_deprecation.py +++ b/libs/core/tests/unit_tests/_api/test_deprecation.py @@ -3,8 +3,8 @@ from typing import Any, Dict import pytest -from langchain._api.deprecation import deprecated, warn_deprecated -from langchain.pydantic_v1 import BaseModel +from langchain_core._api.deprecation import deprecated, warn_deprecated +from langchain_core.pydantic_v1 import BaseModel @pytest.mark.parametrize( diff --git a/libs/langchain/tests/unit_tests/_api/test_imports.py b/libs/core/tests/unit_tests/_api/test_imports.py similarity index 86% rename from libs/langchain/tests/unit_tests/_api/test_imports.py rename to libs/core/tests/unit_tests/_api/test_imports.py index 7b7135b238f..440f76d9574 100644 --- a/libs/langchain/tests/unit_tests/_api/test_imports.py +++ b/libs/core/tests/unit_tests/_api/test_imports.py @@ -1,4 +1,4 @@ -from langchain._api import __all__ +from langchain_core._api import __all__ EXPECTED_ALL = [ "deprecated", diff --git a/libs/langchain/tests/unit_tests/_api/test_path.py b/libs/core/tests/unit_tests/_api/test_path.py similarity index 87% rename from libs/langchain/tests/unit_tests/_api/test_path.py rename to libs/core/tests/unit_tests/_api/test_path.py index 73da680eab9..89428c7cfcb 100644 --- a/libs/langchain/tests/unit_tests/_api/test_path.py +++ b/libs/core/tests/unit_tests/_api/test_path.py @@ -1,6 +1,6 @@ from pathlib import Path -from langchain._api import path +from langchain_core._api import path HERE = Path(__file__).parent @@ -10,7 +10,7 @@ ROOT = HERE.parent.parent.parent def test_as_import_path() -> None: """Test that the path is converted to a LangChain import path.""" # Verify that default paths are correct - assert path.PACKAGE_DIR == ROOT / "langchain" + assert path.PACKAGE_DIR == ROOT / "langchain_core" # Verify that as import path works correctly assert path.as_import_path(HERE, relative_to=ROOT) == "tests.unit_tests._api" assert ( diff --git a/libs/core/tests/unit_tests/data/prompt_file.txt b/libs/core/tests/unit_tests/data/prompt_file.txt new file mode 100644 index 00000000000..0681c36f48e --- /dev/null +++ b/libs/core/tests/unit_tests/data/prompt_file.txt @@ -0,0 +1,2 @@ +Question: {question} +Answer: \ No newline at end of file diff --git a/libs/core/tests/unit_tests/data/prompts/prompt_extra_args.json b/libs/core/tests/unit_tests/data/prompts/prompt_extra_args.json new file mode 100644 index 00000000000..4bfc4fdcc4b --- /dev/null +++ 
b/libs/core/tests/unit_tests/data/prompts/prompt_extra_args.json @@ -0,0 +1,5 @@ +{ + "input_variables": ["foo"], + "template": "This is a {foo} test.", + "bad_var": 1 +} \ No newline at end of file diff --git a/libs/core/tests/unit_tests/data/prompts/prompt_missing_args.json b/libs/core/tests/unit_tests/data/prompts/prompt_missing_args.json new file mode 100644 index 00000000000..cb69d843e7a --- /dev/null +++ b/libs/core/tests/unit_tests/data/prompts/prompt_missing_args.json @@ -0,0 +1,3 @@ +{ + "input_variables": ["foo"] +} \ No newline at end of file diff --git a/libs/core/tests/unit_tests/data/prompts/simple_prompt.json b/libs/core/tests/unit_tests/data/prompts/simple_prompt.json new file mode 100644 index 00000000000..d0f72b1c14f --- /dev/null +++ b/libs/core/tests/unit_tests/data/prompts/simple_prompt.json @@ -0,0 +1,4 @@ +{ + "input_variables": ["foo"], + "template": "This is a {foo} test." +} \ No newline at end of file diff --git a/libs/core/tests/unit_tests/examples/example-non-utf8.csv b/libs/core/tests/unit_tests/examples/example-non-utf8.csv new file mode 100644 index 00000000000..2cd131c433d --- /dev/null +++ b/libs/core/tests/unit_tests/examples/example-non-utf8.csv @@ -0,0 +1,11 @@ +sID,»•i–¼,ŒÚ‹q–¼,ŒÚ‹qID,”„ã,‰¿Ši,‘——¿,“s“¹•{Œ§,»•iƒJƒeƒSƒŠ,Š„ˆø +1,"Eldon ƒXƒ^ƒbƒJƒuƒ‹Žû”[’I—pƒx[ƒXAƒvƒ‰ƒ`ƒi",ƒ‚ƒnƒƒhEƒ}ƒbƒLƒ“ƒ^ƒCƒA,3,-213.25,38.94,35,ƒkƒiƒuƒbƒg€B,•ۊǂƮ—,0.8 +2,"1.7—§•ûƒtƒB[ƒg‚̃Rƒ“ƒpƒNƒguƒLƒ…[ƒuvƒIƒtƒBƒX—â‘ ŒÉ",ƒoƒŠ[EƒtƒŒƒ“ƒ`,293,457.81,208.16,68.02,ƒkƒiƒuƒbƒg€B,‰Æ“d»•i,0.58 +3,"Cardinal Slant-D? ƒŠƒ“ƒO ƒoƒCƒ“ƒ_[Aƒwƒr[ƒQ[ƒW ƒrƒj[ƒ‹",ƒoƒŠ[EƒtƒŒƒ“ƒ`,293,46.71,8.69,2.99,ƒkƒiƒuƒbƒg€B,ƒoƒCƒ“ƒ_[‚¨‚æ‚уoƒCƒ“ƒ_[•t‘®•i,0.39 +4,"R380",ƒNƒŒƒCEƒƒ[ƒ“ƒ_ƒ‹,483,1198.97,195.99,3.99,ƒkƒiƒuƒbƒg€B,“d˜b‚Æ’ÊM,0.58 +5,"ƒz[ƒ€ƒY HEPA ‹ó‹C´ò‹@",ƒJƒ‹ƒƒXEƒ\ƒ‹ƒeƒ,515,30.94,21.78,5.94,ƒkƒiƒuƒbƒg€B,‰Æ“d»•i,0.5 +6,"GE ’·Žõ–½‚̉®“à–„žŒ^“ŠŒõŠí“d‹…",ƒJƒ‹ƒƒXEƒ\ƒ‹ƒeƒ,515,4.43,6.64,4.95,ƒkƒiƒuƒbƒg€B,ƒIƒtƒBƒX‰Æ‹ï,0.37 +7,"ƒƒbƒNƒŠƒ“ƒO•t‚«ƒAƒ“ƒOƒ‹DƒoƒCƒ“ƒ_[Aƒ‰ƒxƒ‹ƒzƒ‹ƒ_[",ƒJ[ƒ‹EƒWƒƒƒNƒ\ƒ“,613,-54.04,7.3,7.72,ƒkƒiƒuƒbƒg€B,ƒoƒCƒ“ƒ_[‚¨‚æ‚уoƒCƒ“ƒ_[•t‘®•i,0.38 +8,"SAFCO ƒ‚ƒoƒCƒ‹ƒfƒXƒNƒTƒCƒhƒtƒ@ƒCƒ‹ ƒƒCƒ„[ƒtƒŒ[ƒ€",ƒJ[ƒ‹EƒWƒƒƒNƒ\ƒ“,613,127.70,42.76,6.22,ƒkƒiƒuƒbƒg€B,•ۊǂƮ—, +9,"SAFCO ‹Æ–±—pƒƒCƒ„[ƒVƒFƒ‹ƒt ƒuƒ‰ƒbƒN",ƒ‚ƒjƒJEƒtƒFƒfƒ‹,643,-695.26,138.14,35,ƒkƒiƒuƒbƒg€B,•ۊǂƮ—, +10,"ƒ[ƒƒbƒNƒX 198",ƒhƒƒV[Eƒoƒbƒ_[ƒY,678,-226.36,4.98,8.33,ƒkƒiƒuƒbƒg€B,ކ,0.38 \ No newline at end of file diff --git a/libs/core/tests/unit_tests/examples/example-non-utf8.txt b/libs/core/tests/unit_tests/examples/example-non-utf8.txt new file mode 100644 index 00000000000..60cbb2073e5 --- /dev/null +++ b/libs/core/tests/unit_tests/examples/example-non-utf8.txt @@ -0,0 +1 @@ +Êàêèå-òî êðàêîçÿáðû diff --git a/libs/core/tests/unit_tests/examples/example-utf8.csv b/libs/core/tests/unit_tests/examples/example-utf8.csv new file mode 100644 index 00000000000..df0169beace --- /dev/null +++ b/libs/core/tests/unit_tests/examples/example-utf8.csv @@ -0,0 +1,11 @@ +"Row ID","Product Name","Customer Name","Customer ID","Sales","Price","Shipping Cost","Province","Product Category","Discount" +1,"Eldon Base for stackable storage shelf, platinum",Muhammed MacIntyre,3,-213.25,38.94,35,Nunavut,Storage & Organization,0.8 +2,"1.7 Cubic Foot Compact ""Cube"" Office Refrigerators",Barry French,293,457.81,208.16,68.02,Nunavut,Appliances,0.58 +3,"Cardinal Slant-D® Ring Binder, Heavy Gauge Vinyl",Barry French,293,46.71,8.69,2.99,Nunavut,Binders and Binder Accessories,0.39 +4,R380,Clay Rozendal,483,1198.97,195.99,3.99,Nunavut,Telephones and Communication,0.58 +5,Holmes HEPA Air 
Purifier,Carlos Soltero,515,30.94,21.78,5.94,Nunavut,Appliances,0.5 +6,G.E. Longer-Life Indoor Recessed Floodlight Bulbs,Carlos Soltero,515,4.43,6.64,4.95,Nunavut,Office Furnishings,0.37 +7,"Angle-D Binders with Locking Rings, Label Holders",Carl Jackson,613,-54.04,7.3,7.72,Nunavut,Binders and Binder Accessories,0.38 +8,"SAFCO Mobile Desk Side File, Wire Frame",Carl Jackson,613,127.70,42.76,6.22,Nunavut,Storage & Organization, +9,"SAFCO Commercial Wire Shelving, Black",Monica Federle,643,-695.26,138.14,35,Nunavut,Storage & Organization, +10,Xerox 198,Dorothy Badders,678,-226.36,4.98,8.33,Nunavut,Paper,0.38 \ No newline at end of file diff --git a/libs/core/tests/unit_tests/examples/example-utf8.txt b/libs/core/tests/unit_tests/examples/example-utf8.txt new file mode 100644 index 00000000000..1bb51996cd6 --- /dev/null +++ b/libs/core/tests/unit_tests/examples/example-utf8.txt @@ -0,0 +1,6 @@ +Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor +incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis +nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. +Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu +fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in +culpa qui officia deserunt mollit anim id est laborum. diff --git a/libs/langchain/tests/unit_tests/examples/example_prompt.json b/libs/core/tests/unit_tests/examples/example_prompt.json similarity index 100% rename from libs/langchain/tests/unit_tests/examples/example_prompt.json rename to libs/core/tests/unit_tests/examples/example_prompt.json diff --git a/libs/langchain/tests/unit_tests/examples/examples.json b/libs/core/tests/unit_tests/examples/examples.json similarity index 100% rename from libs/langchain/tests/unit_tests/examples/examples.json rename to libs/core/tests/unit_tests/examples/examples.json diff --git a/libs/langchain/tests/unit_tests/examples/examples.yaml b/libs/core/tests/unit_tests/examples/examples.yaml similarity index 100% rename from libs/langchain/tests/unit_tests/examples/examples.yaml rename to libs/core/tests/unit_tests/examples/examples.yaml diff --git a/libs/langchain/tests/unit_tests/examples/few_shot_prompt.json b/libs/core/tests/unit_tests/examples/few_shot_prompt.json similarity index 100% rename from libs/langchain/tests/unit_tests/examples/few_shot_prompt.json rename to libs/core/tests/unit_tests/examples/few_shot_prompt.json diff --git a/libs/langchain/tests/unit_tests/examples/few_shot_prompt.yaml b/libs/core/tests/unit_tests/examples/few_shot_prompt.yaml similarity index 100% rename from libs/langchain/tests/unit_tests/examples/few_shot_prompt.yaml rename to libs/core/tests/unit_tests/examples/few_shot_prompt.yaml diff --git a/libs/langchain/tests/unit_tests/examples/few_shot_prompt_example_prompt.json b/libs/core/tests/unit_tests/examples/few_shot_prompt_example_prompt.json similarity index 100% rename from libs/langchain/tests/unit_tests/examples/few_shot_prompt_example_prompt.json rename to libs/core/tests/unit_tests/examples/few_shot_prompt_example_prompt.json diff --git a/libs/langchain/tests/unit_tests/examples/few_shot_prompt_examples_in.json b/libs/core/tests/unit_tests/examples/few_shot_prompt_examples_in.json similarity index 100% rename from libs/langchain/tests/unit_tests/examples/few_shot_prompt_examples_in.json rename to libs/core/tests/unit_tests/examples/few_shot_prompt_examples_in.json diff --git 
a/libs/langchain/tests/unit_tests/examples/few_shot_prompt_yaml_examples.yaml b/libs/core/tests/unit_tests/examples/few_shot_prompt_yaml_examples.yaml similarity index 100% rename from libs/langchain/tests/unit_tests/examples/few_shot_prompt_yaml_examples.yaml rename to libs/core/tests/unit_tests/examples/few_shot_prompt_yaml_examples.yaml diff --git a/libs/langchain/tests/unit_tests/examples/jinja_injection_prompt.json b/libs/core/tests/unit_tests/examples/jinja_injection_prompt.json similarity index 100% rename from libs/langchain/tests/unit_tests/examples/jinja_injection_prompt.json rename to libs/core/tests/unit_tests/examples/jinja_injection_prompt.json diff --git a/libs/langchain/tests/unit_tests/examples/jinja_injection_prompt.yaml b/libs/core/tests/unit_tests/examples/jinja_injection_prompt.yaml similarity index 100% rename from libs/langchain/tests/unit_tests/examples/jinja_injection_prompt.yaml rename to libs/core/tests/unit_tests/examples/jinja_injection_prompt.yaml diff --git a/libs/langchain/tests/unit_tests/examples/prompt_with_output_parser.json b/libs/core/tests/unit_tests/examples/prompt_with_output_parser.json similarity index 100% rename from libs/langchain/tests/unit_tests/examples/prompt_with_output_parser.json rename to libs/core/tests/unit_tests/examples/prompt_with_output_parser.json diff --git a/libs/langchain/tests/unit_tests/examples/simple_prompt.json b/libs/core/tests/unit_tests/examples/simple_prompt.json similarity index 100% rename from libs/langchain/tests/unit_tests/examples/simple_prompt.json rename to libs/core/tests/unit_tests/examples/simple_prompt.json diff --git a/libs/langchain/tests/unit_tests/examples/simple_prompt.yaml b/libs/core/tests/unit_tests/examples/simple_prompt.yaml similarity index 100% rename from libs/langchain/tests/unit_tests/examples/simple_prompt.yaml rename to libs/core/tests/unit_tests/examples/simple_prompt.yaml diff --git a/libs/langchain/tests/unit_tests/examples/simple_prompt_with_template_file.json b/libs/core/tests/unit_tests/examples/simple_prompt_with_template_file.json similarity index 100% rename from libs/langchain/tests/unit_tests/examples/simple_prompt_with_template_file.json rename to libs/core/tests/unit_tests/examples/simple_prompt_with_template_file.json diff --git a/libs/langchain/tests/unit_tests/examples/simple_template.txt b/libs/core/tests/unit_tests/examples/simple_template.txt similarity index 100% rename from libs/langchain/tests/unit_tests/examples/simple_template.txt rename to libs/core/tests/unit_tests/examples/simple_template.txt diff --git a/libs/core/tests/unit_tests/fake/__init__.py b/libs/core/tests/unit_tests/fake/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/core/tests/unit_tests/fake/callbacks.py b/libs/core/tests/unit_tests/fake/callbacks.py new file mode 100644 index 00000000000..aec0c2202ab --- /dev/null +++ b/libs/core/tests/unit_tests/fake/callbacks.py @@ -0,0 +1,391 @@ +"""A fake callback handler for testing purposes.""" +from itertools import chain +from typing import Any, Dict, List, Optional, Union +from uuid import UUID + +from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema.messages import BaseMessage + + +class BaseFakeCallbackHandler(BaseModel): + """Base fake callback handler for testing.""" + + starts: int = 0 + ends: int = 0 + errors: int = 0 + text: int = 0 + ignore_llm_: bool = False + ignore_chain_: bool = False + ignore_agent_: 
bool = False + ignore_retriever_: bool = False + ignore_chat_model_: bool = False + + # to allow for similar callback handlers that are not technically equal + fake_id: Union[str, None] = None + + # add finer-grained counters for easier debugging of failing tests + chain_starts: int = 0 + chain_ends: int = 0 + llm_starts: int = 0 + llm_ends: int = 0 + llm_streams: int = 0 + tool_starts: int = 0 + tool_ends: int = 0 + agent_actions: int = 0 + agent_ends: int = 0 + chat_model_starts: int = 0 + retriever_starts: int = 0 + retriever_ends: int = 0 + retriever_errors: int = 0 + retries: int = 0 + + +class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler): + """Base fake callback handler mixin for testing.""" + + def on_llm_start_common(self) -> None: + self.llm_starts += 1 + self.starts += 1 + + def on_llm_end_common(self) -> None: + self.llm_ends += 1 + self.ends += 1 + + def on_llm_error_common(self) -> None: + self.errors += 1 + + def on_llm_new_token_common(self) -> None: + self.llm_streams += 1 + + def on_retry_common(self) -> None: + self.retries += 1 + + def on_chain_start_common(self) -> None: + self.chain_starts += 1 + self.starts += 1 + + def on_chain_end_common(self) -> None: + self.chain_ends += 1 + self.ends += 1 + + def on_chain_error_common(self) -> None: + self.errors += 1 + + def on_tool_start_common(self) -> None: + self.tool_starts += 1 + self.starts += 1 + + def on_tool_end_common(self) -> None: + self.tool_ends += 1 + self.ends += 1 + + def on_tool_error_common(self) -> None: + self.errors += 1 + + def on_agent_action_common(self) -> None: + self.agent_actions += 1 + self.starts += 1 + + def on_agent_finish_common(self) -> None: + self.agent_ends += 1 + self.ends += 1 + + def on_chat_model_start_common(self) -> None: + self.chat_model_starts += 1 + self.starts += 1 + + def on_text_common(self) -> None: + self.text += 1 + + def on_retriever_start_common(self) -> None: + self.starts += 1 + self.retriever_starts += 1 + + def on_retriever_end_common(self) -> None: + self.ends += 1 + self.retriever_ends += 1 + + def on_retriever_error_common(self) -> None: + self.errors += 1 + self.retriever_errors += 1 + + +class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): + """Fake callback handler for testing.""" + + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return self.ignore_llm_ + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return self.ignore_chain_ + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return self.ignore_agent_ + + @property + def ignore_retriever(self) -> bool: + """Whether to ignore retriever callbacks.""" + return self.ignore_retriever_ + + def on_llm_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_start_common() + + def on_llm_new_token( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_new_token_common() + + def on_llm_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_end_common() + + def on_llm_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_error_common() + + def on_retry( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retry_common() + + def on_chain_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_start_common() + + def on_chain_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_end_common() + + def on_chain_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: +
self.on_chain_error_common() + + def on_tool_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_start_common() + + def on_tool_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_end_common() + + def on_tool_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_error_common() + + def on_agent_action( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_agent_action_common() + + def on_agent_finish( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_agent_finish_common() + + def on_text( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_text_common() + + def on_retriever_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retriever_start_common() + + def on_retriever_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retriever_end_common() + + def on_retriever_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retriever_error_common() + + def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": + return self + + +class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + assert all(isinstance(m, BaseMessage) for m in chain(*messages)) + self.on_chat_model_start_common() + + +class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin): + """Fake async callback handler for testing.""" + + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return self.ignore_llm_ + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return self.ignore_chain_ + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return self.ignore_agent_ + + async def on_retry( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retry_common() + + async def on_llm_start( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_start_common() + + async def on_llm_new_token( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_new_token_common() + + async def on_llm_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_end_common() + + async def on_llm_error( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_error_common() + + async def on_chain_start( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_chain_start_common() + + async def on_chain_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_chain_end_common() + + async def on_chain_error( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_chain_error_common() + + async def on_tool_start( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_tool_start_common() + + async def on_tool_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_tool_end_common() + + async def on_tool_error( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_tool_error_common() + + async def on_agent_action( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_agent_action_common() + + async def on_agent_finish( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_agent_finish_common() + + async def on_text( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_text_common() + + def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": + return self diff --git 
a/libs/core/tests/unit_tests/fake/chat_model.py b/libs/core/tests/unit_tests/fake/chat_model.py new file mode 100644 index 00000000000..4a5a84064d0 --- /dev/null +++ b/libs/core/tests/unit_tests/fake/chat_model.py @@ -0,0 +1,105 @@ +"""Fake ChatModel for testing purposes.""" +import asyncio +import time +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union + +from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.chat_model import BaseChatModel, SimpleChatModel +from langchain_core.schema import ChatResult +from langchain_core.schema.messages import AIMessageChunk, BaseMessage +from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk + + +class FakeMessagesListChatModel(BaseChatModel): + """Fake ChatModel for testing purposes.""" + + responses: List[BaseMessage] + sleep: Optional[float] = None + i: int = 0 + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + response = self.responses[self.i] + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 + generation = ChatGeneration(message=response) + return ChatResult(generations=[generation]) + + @property + def _llm_type(self) -> str: + return "fake-messages-list-chat-model" + + +class FakeListChatModel(SimpleChatModel): + """Fake ChatModel for testing purposes.""" + + responses: List + sleep: Optional[float] = None + i: int = 0 + + @property + def _llm_type(self) -> str: + return "fake-list-chat-model" + + def _call( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """First try to lookup in queries, else return 'foo' or 'bar'.""" + response = self.responses[self.i] + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 + return response + + def _stream( + self, + messages: List[BaseMessage], + stop: Union[List[str], None] = None, + run_manager: Union[CallbackManagerForLLMRun, None] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + response = self.responses[self.i] + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 + for c in response: + if self.sleep is not None: + time.sleep(self.sleep) + yield ChatGenerationChunk(message=AIMessageChunk(content=c)) + + async def _astream( + self, + messages: List[BaseMessage], + stop: Union[List[str], None] = None, + run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + response = self.responses[self.i] + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 + for c in response: + if self.sleep is not None: + await asyncio.sleep(self.sleep) + yield ChatGenerationChunk(message=AIMessageChunk(content=c)) + + @property + def _identifying_params(self) -> Dict[str, Any]: + return {"responses": self.responses} diff --git a/libs/core/tests/unit_tests/fake/llm.py b/libs/core/tests/unit_tests/fake/llm.py new file mode 100644 index 00000000000..fa1d92b1043 --- /dev/null +++ b/libs/core/tests/unit_tests/fake/llm.py @@ -0,0 +1,90 @@ +import asyncio +import time +from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional + +from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.llm import LLM +from 
langchain_core.runnables import RunnableConfig +from langchain_core.schema.language_model import LanguageModelInput + + +class FakeListLLM(LLM): + """Fake LLM for testing purposes.""" + + responses: List[str] + sleep: Optional[float] = None + i: int = 0 + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "fake-list" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Return next response""" + response = self.responses[self.i] + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 + return response + + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Return next response""" + response = self.responses[self.i] + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 + return response + + @property + def _identifying_params(self) -> Mapping[str, Any]: + return {"responses": self.responses} + + +class FakeStreamingListLLM(FakeListLLM): + """Fake streaming list LLM for testing purposes.""" + + def stream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> Iterator[str]: + result = self.invoke(input, config) + for c in result: + if self.sleep is not None: + time.sleep(self.sleep) + yield c + + async def astream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> AsyncIterator[str]: + result = await self.ainvoke(input, config) + for c in result: + if self.sleep is not None: + await asyncio.sleep(self.sleep) + yield c diff --git a/libs/core/tests/unit_tests/fake/memory.py b/libs/core/tests/unit_tests/fake/memory.py new file mode 100644 index 00000000000..3dc5142e461 --- /dev/null +++ b/libs/core/tests/unit_tests/fake/memory.py @@ -0,0 +1,23 @@ +from typing import List + +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import ( + BaseChatMessageHistory, +) +from langchain_core.schema.messages import BaseMessage + + +class ChatMessageHistory(BaseChatMessageHistory, BaseModel): + """In memory implementation of chat message history. + + Stores messages in an in memory list. 
+ """ + + messages: List[BaseMessage] = Field(default_factory=list) + + def add_message(self, message: BaseMessage) -> None: + """Add a self-created message to the store""" + self.messages.append(message) + + def clear(self) -> None: + self.messages = [] diff --git a/libs/core/tests/unit_tests/prompt_file.txt b/libs/core/tests/unit_tests/prompt_file.txt new file mode 100644 index 00000000000..0681c36f48e --- /dev/null +++ b/libs/core/tests/unit_tests/prompt_file.txt @@ -0,0 +1,2 @@ +Question: {question} +Answer: \ No newline at end of file diff --git a/libs/langchain/tests/unit_tests/prompts/__init__.py b/libs/core/tests/unit_tests/prompts/__init__.py similarity index 100% rename from libs/langchain/tests/unit_tests/prompts/__init__.py rename to libs/core/tests/unit_tests/prompts/__init__.py diff --git a/libs/core/tests/unit_tests/prompts/prompt_extra_args.json b/libs/core/tests/unit_tests/prompts/prompt_extra_args.json new file mode 100644 index 00000000000..4bfc4fdcc4b --- /dev/null +++ b/libs/core/tests/unit_tests/prompts/prompt_extra_args.json @@ -0,0 +1,5 @@ +{ + "input_variables": ["foo"], + "template": "This is a {foo} test.", + "bad_var": 1 +} \ No newline at end of file diff --git a/libs/core/tests/unit_tests/prompts/prompt_missing_args.json b/libs/core/tests/unit_tests/prompts/prompt_missing_args.json new file mode 100644 index 00000000000..cb69d843e7a --- /dev/null +++ b/libs/core/tests/unit_tests/prompts/prompt_missing_args.json @@ -0,0 +1,3 @@ +{ + "input_variables": ["foo"] +} \ No newline at end of file diff --git a/libs/core/tests/unit_tests/prompts/simple_prompt.json b/libs/core/tests/unit_tests/prompts/simple_prompt.json new file mode 100644 index 00000000000..d0f72b1c14f --- /dev/null +++ b/libs/core/tests/unit_tests/prompts/simple_prompt.json @@ -0,0 +1,4 @@ +{ + "input_variables": ["foo"], + "template": "This is a {foo} test." 
+} \ No newline at end of file diff --git a/libs/langchain/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py similarity index 98% rename from libs/langchain/tests/unit_tests/prompts/test_chat.py rename to libs/core/tests/unit_tests/prompts/test_chat.py index bbdca9d75ec..6272d0b2c91 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -3,8 +3,8 @@ from typing import Any, List, Union import pytest -from langchain.prompts import PromptTemplate -from langchain.prompts.chat import ( +from langchain_core.prompts import PromptTemplate +from langchain_core.prompts.chat import ( AIMessagePromptTemplate, BaseMessagePromptTemplate, ChatMessage, @@ -15,7 +15,7 @@ from langchain.prompts.chat import ( SystemMessagePromptTemplate, _convert_to_message, ) -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( AIMessage, BaseMessage, HumanMessage, diff --git a/libs/langchain/tests/unit_tests/prompts/test_few_shot.py b/libs/core/tests/unit_tests/prompts/test_few_shot.py similarity index 97% rename from libs/langchain/tests/unit_tests/prompts/test_few_shot.py rename to libs/core/tests/unit_tests/prompts/test_few_shot.py index 69e9f487b2f..5955ba8cb00 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_few_shot.py +++ b/libs/core/tests/unit_tests/prompts/test_few_shot.py @@ -3,19 +3,19 @@ from typing import Any, Dict, List, Sequence, Tuple import pytest -from langchain.prompts import ( +from langchain_core.prompts import ( AIMessagePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, ) -from langchain.prompts.chat import SystemMessagePromptTemplate -from langchain.prompts.example_selector.base import BaseExampleSelector -from langchain.prompts.few_shot import ( +from langchain_core.prompts.chat import SystemMessagePromptTemplate +from langchain_core.prompts.example_selector.base import BaseExampleSelector +from langchain_core.prompts.few_shot import ( FewShotChatMessagePromptTemplate, FewShotPromptTemplate, ) -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import AIMessage, HumanMessage, SystemMessage +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.schema import AIMessage, HumanMessage, SystemMessage EXAMPLE_PROMPT = PromptTemplate( input_variables=["question", "answer"], template="{question}: {answer}" diff --git a/libs/langchain/tests/unit_tests/prompts/test_few_shot_with_templates.py b/libs/core/tests/unit_tests/prompts/test_few_shot_with_templates.py similarity index 93% rename from libs/langchain/tests/unit_tests/prompts/test_few_shot_with_templates.py rename to libs/core/tests/unit_tests/prompts/test_few_shot_with_templates.py index bf91eaaeb06..0ed3987b550 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_few_shot_with_templates.py +++ b/libs/core/tests/unit_tests/prompts/test_few_shot_with_templates.py @@ -2,8 +2,8 @@ import pytest -from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates +from langchain_core.prompts.prompt import PromptTemplate EXAMPLE_PROMPT = PromptTemplate( input_variables=["question", "answer"], template="{question}: {answer}" diff --git a/libs/langchain/tests/unit_tests/prompts/test_imports.py b/libs/core/tests/unit_tests/prompts/test_imports.py similarity index 90% rename from 
libs/langchain/tests/unit_tests/prompts/test_imports.py rename to libs/core/tests/unit_tests/prompts/test_imports.py index 6ec17789ad0..b70a3e6fc2a 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_imports.py +++ b/libs/core/tests/unit_tests/prompts/test_imports.py @@ -1,4 +1,4 @@ -from langchain.prompts import __all__ +from langchain_core.prompts import __all__ EXPECTED_ALL = [ "AIMessagePromptTemplate", @@ -12,7 +12,6 @@ EXPECTED_ALL = [ "LengthBasedExampleSelector", "MaxMarginalRelevanceExampleSelector", "MessagesPlaceholder", - "NGramOverlapExampleSelector", "PipelinePromptTemplate", "Prompt", "PromptTemplate", diff --git a/libs/langchain/tests/unit_tests/prompts/test_length_based_example_selector.py b/libs/core/tests/unit_tests/prompts/test_length_based_example_selector.py similarity index 92% rename from libs/langchain/tests/unit_tests/prompts/test_length_based_example_selector.py rename to libs/core/tests/unit_tests/prompts/test_length_based_example_selector.py index 38fd689c4e8..59e35f8f6e5 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_length_based_example_selector.py +++ b/libs/core/tests/unit_tests/prompts/test_length_based_example_selector.py @@ -1,8 +1,10 @@ """Test functionality related to length based selector.""" import pytest -from langchain.prompts.example_selector.length_based import LengthBasedExampleSelector -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.example_selector.length_based import ( + LengthBasedExampleSelector, +) +from langchain_core.prompts.prompt import PromptTemplate EXAMPLES = [ {"question": "Question: who are you?\nAnswer: foo"}, diff --git a/libs/langchain/tests/unit_tests/prompts/test_loading.py b/libs/core/tests/unit_tests/prompts/test_loading.py similarity index 87% rename from libs/langchain/tests/unit_tests/prompts/test_loading.py rename to libs/core/tests/unit_tests/prompts/test_loading.py index 893ce4debd9..a89ffaf186c 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_loading.py +++ b/libs/core/tests/unit_tests/prompts/test_loading.py @@ -6,10 +6,9 @@ from typing import Iterator import pytest -from langchain.output_parsers import RegexParser -from langchain.prompts.few_shot import FewShotPromptTemplate -from langchain.prompts.loading import load_prompt -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.few_shot import FewShotPromptTemplate +from langchain_core.prompts.loading import load_prompt +from langchain_core.prompts.prompt import PromptTemplate EXAMPLE_DIR = Path("tests/unit_tests/examples").absolute() @@ -176,18 +175,3 @@ def test_loading_few_shot_prompt_example_prompt() -> None: suffix="Input: {adjective}\nOutput:", ) assert prompt == expected_prompt - - -def test_loading_with_output_parser() -> None: - with change_directory(EXAMPLE_DIR): - prompt = load_prompt("prompt_with_output_parser.json") - expected_template = "Given the following question and student answer, provide a correct answer and score the student answer.\nQuestion: {question}\nStudent Answer: {student_answer}\nCorrect Answer:" # noqa: E501 - expected_prompt = PromptTemplate( - input_variables=["question", "student_answer"], - output_parser=RegexParser( - regex="(.*?)\nScore: (.*)", - output_keys=["answer", "score"], - ), - template=expected_template, - ) - assert prompt == expected_prompt diff --git a/libs/langchain/tests/unit_tests/prompts/test_pipeline_prompt.py b/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py similarity index 88% rename from 
libs/langchain/tests/unit_tests/prompts/test_pipeline_prompt.py rename to libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py index d7ec03fec3f..3adca5c7b5c 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_pipeline_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py @@ -1,6 +1,6 @@ -from langchain.prompts.chat import ChatPromptTemplate, MessagesPlaceholder -from langchain.prompts.pipeline import PipelinePromptTemplate -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.prompts.pipeline import PipelinePromptTemplate +from langchain_core.prompts.prompt import PromptTemplate def test_get_input_variables() -> None: diff --git a/libs/langchain/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py similarity index 99% rename from libs/langchain/tests/unit_tests/prompts/test_prompt.py rename to libs/core/tests/unit_tests/prompts/test_prompt.py index 8e522519495..f931dd80bc1 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -1,7 +1,7 @@ """Test functionality related to prompts.""" import pytest -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate def test_prompt_valid() -> None: diff --git a/libs/langchain/tests/unit_tests/prompts/test_utils.py b/libs/core/tests/unit_tests/prompts/test_utils.py similarity index 76% rename from libs/langchain/tests/unit_tests/prompts/test_utils.py rename to libs/core/tests/unit_tests/prompts/test_utils.py index 479d02e8bd9..ac3cf84e43e 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_utils.py +++ b/libs/core/tests/unit_tests/prompts/test_utils.py @@ -1,5 +1,5 @@ """Test functionality related to prompt utils.""" -from langchain.prompts.example_selector.semantic_similarity import sorted_values +from langchain_core.prompts.example_selector.semantic_similarity import sorted_values def test_sorted_vals() -> None: diff --git a/libs/core/tests/unit_tests/runnable/__init__.py b/libs/core/tests/unit_tests/runnable/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnable/__snapshots__/test_runnable.ambr similarity index 72% rename from libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr rename to libs/core/tests/unit_tests/runnable/__snapshots__/test_runnable.ambr index 2c349e96c0a..d91c0f5345a 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnable/__snapshots__/test_runnable.ambr @@ -5,9 +5,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -15,7 +14,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -26,7 +25,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "SystemMessagePromptTemplate" @@ -36,7 +35,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -54,7 +53,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -64,7 +63,7 @@ "lc": 1, 
"type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -91,9 +90,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "chat_models", + "tests", + "unit_tests", "fake", + "chat_model", "FakeListChatModel" ], "repr": "FakeListChatModel(responses=['foo, bar'])" @@ -103,7 +103,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "output_parsers", "list", "CommaSeparatedListOutputParser" @@ -120,9 +120,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -130,9 +129,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -143,7 +141,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -154,7 +152,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "SystemMessagePromptTemplate" @@ -164,7 +162,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -182,7 +180,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -192,7 +190,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -218,9 +216,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "chat_models", + "tests", + "unit_tests", "fake", + "chat_model", "FakeListChatModel" ], "repr": "FakeListChatModel(responses=['baz, qux'])" @@ -230,7 +229,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "output_parsers", "list", "CommaSeparatedListOutputParser" @@ -247,9 +246,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -257,7 +255,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -268,7 +266,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "SystemMessagePromptTemplate" @@ -278,7 +276,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -296,7 +294,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -306,7 +304,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -333,9 +331,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "chat_models", + "tests", + "unit_tests", "fake", + "chat_model", "FakeListChatModel" ], "repr": "FakeListChatModel(responses=['foo, bar'])" @@ -344,7 +343,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "output_parsers", "list", "CommaSeparatedListOutputParser" @@ -355,9 +354,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -367,7 +365,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -378,7 +376,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", 
"SystemMessagePromptTemplate" @@ -388,7 +386,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -406,7 +404,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -416,7 +414,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -442,9 +440,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "chat_models", + "tests", + "unit_tests", "fake", + "chat_model", "FakeListChatModel" ], "repr": "FakeListChatModel(responses=['baz, qux'])" @@ -454,7 +453,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "output_parsers", "list", "CommaSeparatedListOutputParser" @@ -467,7 +466,7 @@ # --- # name: test_combining_sequences.3 list([ - Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo, bar'])"}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda'], 'repr': "RunnableLambda(lambda x: {'question': x[0] + x[1]})"}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nicer assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['baz, qux'])"}], 
'last': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ['baz', 'qux']}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo, bar'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo, bar'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo, bar', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo, bar'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), 
name='CommaSeparatedListOutputParser', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='parser', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': AIMessage(content='foo, bar')}, outputs={'output': ['foo', 'bar']}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:3'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda'], 'repr': "RunnableLambda(lambda x: {'question': x[0] + x[1]})"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': ['foo', 'bar']}, outputs={'question': 'foobar'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:4'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nicer assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'foobar'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nicer assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'foobar', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:5'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000006'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['baz, qux'], '_type': 'fake-list-chat-model', 'stop': None}, 
'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['baz, qux'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nicer assistant.\nHuman: foobar']}, outputs={'generations': [[{'text': 'baz, qux', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'baz, qux'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:6'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000007'), name='CommaSeparatedListOutputParser', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='parser', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': AIMessage(content='baz, qux')}, outputs={'output': ['baz', 'qux']}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:7'], execution_order=None, child_execution_order=None, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'runnables', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'chat_model', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo, bar'])"}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, {'lc': 1, 'type': 'not_implemented', 'id': ['langchain_core', 'runnables', 'base', 'RunnableLambda'], 'repr': "RunnableLambda(lambda x: {'question': x[0] + x[1]})"}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 
'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nicer assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, {'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'chat_model', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['baz, qux'])"}], 'last': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ['baz', 'qux']}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo, bar'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['tests', 
'unit_tests', 'fake', 'chat_model', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo, bar'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo, bar', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo, bar'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='CommaSeparatedListOutputParser', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='parser', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': AIMessage(content='foo, bar')}, outputs={'output': ['foo', 'bar']}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:3'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain_core', 'runnables', 'base', 'RunnableLambda'], 'repr': "RunnableLambda(lambda x: {'question': x[0] + x[1]})"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': ['foo', 'bar']}, outputs={'question': 'foobar'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:4'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nicer assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'foobar'}, outputs={'lc': 1, 'type': 'constructor', 'id': 
['langchain_core', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nicer assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'foobar', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:5'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000006'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['baz, qux'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'chat_model', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['baz, qux'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nicer assistant.\nHuman: foobar']}, outputs={'generations': [[{'text': 'baz, qux', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'baz, qux'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:6'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000007'), name='CommaSeparatedListOutputParser', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='parser', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': AIMessage(content='baz, qux')}, outputs={'output': ['baz', 'qux']}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:7'], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- # name: test_each @@ -476,9 +475,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -486,7 +484,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -497,7 +495,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "SystemMessagePromptTemplate" @@ -507,7 +505,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -525,7 +523,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -535,7 +533,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -562,9 +560,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeStreamingListLLM" ], "repr": 
"FakeStreamingListLLM(responses=['first item, second item, third item'])" @@ -575,7 +574,6 @@ "id": [ "tests", "unit_tests", - "schema", "runnable", "test_runnable", "FakeSplitIntoListParser" @@ -587,9 +585,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableEach" ], "kwargs": { @@ -597,9 +594,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeStreamingListLLM" ], "repr": "FakeStreamingListLLM(responses=['this', 'is', 'a', 'test'])" @@ -616,9 +614,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -626,9 +623,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableParallel" ], "kwargs": { @@ -637,9 +633,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -649,9 +644,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableParallel" ], "kwargs": { @@ -660,9 +654,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -678,9 +671,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -696,9 +688,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -706,9 +697,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableParallel" ], "kwargs": { @@ -717,9 +707,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -732,9 +721,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableWithFallbacks" ], "kwargs": { @@ -742,9 +730,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -752,7 +739,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -770,9 +757,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=['foo'], i=1)" @@ -784,9 +772,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -794,7 +781,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -812,9 +799,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=['bar'])" @@ -845,9 +833,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableWithFallbacks" ], "kwargs": { @@ -855,9 +842,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + 
"tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=['foo'], i=1)" @@ -867,9 +855,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=['bar'])" @@ -896,9 +885,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableWithFallbacks" ], "kwargs": { @@ -906,9 +894,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=['foo'], i=1)" @@ -918,9 +907,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=['baz'], i=1)" @@ -929,9 +919,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=['bar'])" @@ -964,9 +955,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -974,7 +964,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -985,7 +975,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "SystemMessagePromptTemplate" @@ -995,7 +985,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1013,7 +1003,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -1023,7 +1013,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1049,9 +1039,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "chat_models", + "tests", + "unit_tests", "fake", + "chat_model", "FakeListChatModel" ], "repr": "FakeListChatModel(responses=['foo'])" @@ -1062,7 +1053,7 @@ # --- # name: test_prompt_with_chat_model.2 list([ - Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': 
"FakeListChatModel(responses=['foo'])"}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': AIMessage(content='foo')}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 
0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'runnables', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'chat_model', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo'])"}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': AIMessage(content='foo')}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 
1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'chat_model', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- # name: test_prompt_with_chat_model_and_parser @@ -1071,9 +1062,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -1081,7 +1071,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -1092,7 +1082,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "SystemMessagePromptTemplate" @@ -1102,7 +1092,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1120,7 +1110,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -1130,7 +1120,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1157,9 +1147,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "chat_models", + "tests", + "unit_tests", "fake", + "chat_model", "FakeListChatModel" ], "repr": "FakeListChatModel(responses=['foo, bar'])" @@ -1169,7 +1160,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "output_parsers", "list", "CommaSeparatedListOutputParser" @@ -1182,7 +1173,7 @@ # --- # name: test_prompt_with_chat_model_and_parser.1 list([ - Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 
'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo, bar'])"}], 'last': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ['foo', 'bar']}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo, bar'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo, bar'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo, bar', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo, bar'}}}]], 'llm_output': None, 'run': None}, 
reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='CommaSeparatedListOutputParser', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='parser', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': AIMessage(content='foo, bar')}, outputs={'output': ['foo', 'bar']}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:3'], execution_order=None, child_execution_order=None, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'runnables', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'chat_model', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo, bar'])"}], 'last': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ['foo', 'bar']}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 
'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo, bar'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'chat_model', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo, bar'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo, bar', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo, bar'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='CommaSeparatedListOutputParser', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='parser', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': AIMessage(content='foo, bar')}, outputs={'output': ['foo', 'bar']}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:3'], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- # name: test_prompt_with_chat_model_async @@ -1197,9 +1188,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -1207,7 +1197,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -1218,7 +1208,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "SystemMessagePromptTemplate" @@ 
-1228,7 +1218,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1246,7 +1236,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -1256,7 +1246,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1282,9 +1272,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "chat_models", + "tests", + "unit_tests", "fake", + "chat_model", "FakeListChatModel" ], "repr": "FakeListChatModel(responses=['foo'])" @@ -1295,7 +1286,7 @@ # --- # name: test_prompt_with_chat_model_async.2 list([ - Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo'])"}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': AIMessage(content='foo')}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': 
FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'runnables', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'chat_model', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo'])"}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': AIMessage(content='foo')}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, 
child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'chat_model', 'FakeListChatModel'], 'repr': "FakeListChatModel(responses=['foo'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'type': 'ChatGeneration', 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- # name: test_prompt_with_llm @@ -1304,9 +1295,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -1314,7 +1304,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -1325,7 +1315,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", 
"SystemMessagePromptTemplate" @@ -1335,7 +1325,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1353,7 +1343,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -1363,7 +1353,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1389,9 +1379,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=['foo', 'bar'])" @@ -1402,13 +1393,13 @@ # --- # name: test_prompt_with_llm.1 list([ - Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'])"}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 
0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'type': 'Generation'}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'runnables', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'llm', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'])"}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, 
serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'llm', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'type': 'Generation'}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- # name: test_prompt_with_llm.2 list([ - Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 
'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'], i=1)"}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'bar'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'], i=1)"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'bar', 'generation_info': None, 'type': 'Generation'}]], 'llm_output': None, 'run': None}, reference_example_id=None, 
parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[])]), - Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'], i=1)"}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your favorite color?', 'additional_kwargs': {}}}]}}, 
reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'], i=1)"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your favorite color?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'type': 'Generation'}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'runnables', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'llm', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'], i=1)"}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'bar'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 
1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'llm', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'], i=1)"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'bar', 'generation_info': None, 'type': 'Generation'}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'runnables', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'llm', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'], 
i=1)"}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your favorite color?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'llm', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'], i=1)"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your favorite color?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'type': 'Generation'}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- # name: test_prompt_with_llm_and_async_lambda @@ -1417,9 +1408,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -1427,7 +1417,7 @@ "lc": 1, "type": "constructor", "id": [ - 
"langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -1438,7 +1428,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "SystemMessagePromptTemplate" @@ -1448,7 +1438,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1466,7 +1456,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -1476,7 +1466,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1503,9 +1493,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=['foo', 'bar'])" @@ -1515,9 +1506,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -1529,7 +1519,7 @@ # --- # name: test_prompt_with_llm_and_async_lambda.1 list([ - Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'])"}], 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda'], 'repr': 'RunnableLambda(afunc=...)'}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 
'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'])"}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'type': 'Generation'}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='passthrough', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda'], 'repr': 'RunnableLambda(afunc=...)'}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'foo'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:3'], execution_order=None, child_execution_order=None, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'runnables', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 
'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'llm', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'])"}], 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain_core', 'runnables', 'base', 'RunnableLambda'], 'repr': 'RunnableLambda(afunc=...)'}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain_core', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:1'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['tests', 'unit_tests', 'fake', 'llm', 'FakeListLLM'], 'repr': "FakeListLLM(responses=['foo', 'bar'])"}, 
events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None, 'type': 'Generation'}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:2'], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='passthrough', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain_core', 'runnables', 'base', 'RunnableLambda'], 'repr': 'RunnableLambda(afunc=...)'}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'foo'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=['seq:step:3'], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- # name: test_router_runnable @@ -1538,9 +1528,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -1548,9 +1537,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableParallel" ], "kwargs": { @@ -1559,9 +1547,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -1571,9 +1558,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableParallel" ], "kwargs": { @@ -1582,9 +1568,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -1600,9 +1585,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RouterRunnable" ], "kwargs": { @@ -1611,9 +1595,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -1621,7 +1604,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -1635,7 +1618,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -1645,7 +1628,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1668,9 +1651,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=['4'])" @@ -1681,9 +1665,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -1691,7 +1674,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -1705,7 +1688,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -1715,7 +1698,7 @@ "lc": 1, "type": 
"constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -1738,9 +1721,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=['2'])" @@ -4316,9 +4300,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -4326,9 +4309,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableParallel" ], "kwargs": { @@ -4337,9 +4319,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -4347,9 +4328,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnablePassthrough" ], "kwargs": { @@ -4362,9 +4342,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -4376,9 +4355,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -4386,9 +4364,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -4400,7 +4377,6 @@ "id": [ "tests", "unit_tests", - "schema", "runnable", "test_runnable", "FakeRetriever" @@ -4413,9 +4389,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -4429,7 +4404,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -4440,7 +4415,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "SystemMessagePromptTemplate" @@ -4450,7 +4425,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -4468,7 +4443,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -4478,7 +4453,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -4506,9 +4481,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "chat_models", + "tests", + "unit_tests", "fake", + "chat_model", "FakeListChatModel" ], "repr": "FakeListChatModel(responses=['foo, bar'])" @@ -4518,7 +4494,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "output_parsers", "list", "CommaSeparatedListOutputParser" @@ -4545,9 +4521,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -4555,7 +4530,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -4566,7 +4541,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "SystemMessagePromptTemplate" @@ -4576,7 +4551,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -4594,7 +4569,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", 
"chat", "HumanMessagePromptTemplate" @@ -4604,7 +4579,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -4631,9 +4606,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -4644,9 +4618,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableParallel" ], "kwargs": { @@ -4655,9 +4628,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "chat_models", + "tests", + "unit_tests", "fake", + "chat_model", "FakeListChatModel" ], "repr": "FakeListChatModel(responses=[\"i'm a chatbot\"])" @@ -4666,9 +4640,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=[\"i'm a textbot\"])" @@ -4686,9 +4661,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableSequence" ], "kwargs": { @@ -4696,7 +4670,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -4707,7 +4681,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "SystemMessagePromptTemplate" @@ -4717,7 +4691,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -4735,7 +4709,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -4745,7 +4719,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -4772,9 +4746,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], @@ -4785,9 +4758,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableParallel" ], "kwargs": { @@ -4796,9 +4768,8 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "RunnableBinding" ], "kwargs": { @@ -4806,9 +4777,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "chat_models", + "tests", + "unit_tests", "fake", + "chat_model", "FakeListChatModel" ], "repr": "FakeListChatModel(responses=[\"i'm a chatbot\"])" @@ -4828,9 +4800,10 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "llms", + "tests", + "unit_tests", "fake", + "llm", "FakeListLLM" ], "repr": "FakeListLLM(responses=[\"i'm a textbot\"])" @@ -4839,9 +4812,8 @@ "lc": 1, "type": "not_implemented", "id": [ - "langchain", - "schema", - "runnable", + "langchain_core", + "runnables", "base", "RunnableLambda" ], diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_config.py b/libs/core/tests/unit_tests/runnable/test_config.py similarity index 75% rename from libs/langchain/tests/unit_tests/schema/runnable/test_config.py rename to libs/core/tests/unit_tests/runnable/test_config.py index 410fee06108..e15c5e48a4a 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_config.py +++ b/libs/core/tests/unit_tests/runnable/test_config.py @@ -1,8 +1,8 @@ -from langchain.callbacks.manager import CallbackManager -from langchain.callbacks.stdout import StdOutCallbackHandler -from 
langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler -from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler -from langchain.schema.runnable.config import RunnableConfig, merge_configs +from langchain_core.callbacks.manager import CallbackManager +from langchain_core.callbacks.stdout import StdOutCallbackHandler +from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler +from langchain_core.runnables.config import RunnableConfig, merge_configs def test_merge_config_callbacks() -> None: diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_history.py b/libs/core/tests/unit_tests/runnable/test_history.py similarity index 95% rename from libs/langchain/tests/unit_tests/schema/runnable/test_history.py rename to libs/core/tests/unit_tests/runnable/test_history.py index b19661a6265..534ff12b44c 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_history.py +++ b/libs/core/tests/unit_tests/runnable/test_history.py @@ -1,10 +1,10 @@ from typing import Any, Callable, Sequence, Union -from langchain.memory import ChatMessageHistory -from langchain.pydantic_v1 import BaseModel -from langchain.schema import AIMessage, BaseMessage, HumanMessage -from langchain.schema.runnable import RunnableConfig, RunnableLambda -from langchain.schema.runnable.history import RunnableWithMessageHistory +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import RunnableConfig, RunnableLambda +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.schema import AIMessage, BaseMessage, HumanMessage +from tests.unit_tests.fake.memory import ChatMessageHistory def _get_get_session_history() -> Callable[..., ChatMessageHistory]: diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/core/tests/unit_tests/runnable/test_runnable.py similarity index 88% rename from libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py rename to libs/core/tests/unit_tests/runnable/test_runnable.py index 56af1734cd3..21c5a28b689 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/core/tests/unit_tests/runnable/test_runnable.py @@ -20,42 +20,29 @@ from pytest_mock import MockerFixture from syrupy import SnapshotAssertion from typing_extensions import TypedDict -from langchain.callbacks.manager import ( +from langchain_core.callbacks.manager import ( Callbacks, atrace_as_chain_group, collect_runs, trace_as_chain_group, ) -from langchain.callbacks.tracers.base import BaseTracer -from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch -from langchain.callbacks.tracers.schemas import Run -from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler -from langchain.chains.question_answering import load_qa_chain -from langchain.chains.summarize import load_summarize_chain -from langchain.chat_models.fake import FakeListChatModel -from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM -from langchain.load.dump import dumpd, dumps -from langchain.output_parsers.list import CommaSeparatedListOutputParser -from langchain.prompts import PromptTemplate -from langchain.prompts.base import StringPromptValue -from langchain.prompts.chat import ( +from langchain_core.callbacks.tracers.base import BaseTracer +from langchain_core.callbacks.tracers.log_stream import RunLog, RunLogPatch +from 
langchain_core.callbacks.tracers.schemas import Run +from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler +from langchain_core.load.dump import dumpd, dumps +from langchain_core.output_parsers.list import CommaSeparatedListOutputParser +from langchain_core.prompts import PromptTemplate +from langchain_core.prompts.base import StringPromptValue +from langchain_core.prompts.chat import ( ChatPromptTemplate, ChatPromptValue, HumanMessagePromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate, ) -from langchain.pydantic_v1 import BaseModel -from langchain.schema.document import Document -from langchain.schema.messages import ( - AIMessage, - AIMessageChunk, - HumanMessage, - SystemMessage, -) -from langchain.schema.output_parser import BaseOutputParser, StrOutputParser -from langchain.schema.retriever import BaseRetriever -from langchain.schema.runnable import ( +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import ( RouterRunnable, Runnable, RunnableBranch, @@ -66,18 +53,28 @@ from langchain.schema.runnable import ( RunnableSequence, RunnableWithFallbacks, ) -from langchain.schema.runnable.base import ( +from langchain_core.runnables.base import ( ConfigurableField, RunnableBinding, RunnableGenerator, ) -from langchain.schema.runnable.utils import ( +from langchain_core.runnables.utils import ( ConfigurableFieldMultiOption, ConfigurableFieldSingleOption, add, ) -from langchain.tools.base import BaseTool, tool -from langchain.tools.json.tool import JsonListKeysTool, JsonSpec +from langchain_core.schema.document import Document +from langchain_core.schema.messages import ( + AIMessage, + AIMessageChunk, + HumanMessage, + SystemMessage, +) +from langchain_core.schema.output_parser import BaseOutputParser, StrOutputParser +from langchain_core.schema.retriever import BaseRetriever +from langchain_core.tool import BaseTool, tool +from tests.unit_tests.fake.chat_model import FakeListChatModel +from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM class FakeTracer(BaseTracer): @@ -583,18 +580,6 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: }, } - json_list_keys_tool = JsonListKeysTool(spec=JsonSpec(dict_={})) - - assert json_list_keys_tool.input_schema.schema() == { - "title": "json_spec_list_keysSchema", - "type": "object", - "properties": {"tool_input": {"title": "Tool Input", "type": "string"}}, - "required": ["tool_input"], - } - assert json_list_keys_tool.output_schema.schema() == { - "title": "JsonListKeysToolOutput" - } - def test_passthrough_assign_schema() -> None: retriever = FakeRetriever() # str -> List[Document] @@ -814,133 +799,6 @@ def test_schema_complex_seq() -> None: } -def test_schema_chains() -> None: - model = FakeListChatModel(responses=[""]) - - stuff_chain = load_summarize_chain(model) - - assert stuff_chain.input_schema.schema() == { - "title": "CombineDocumentsInput", - "type": "object", - "properties": { - "input_documents": { - "title": "Input Documents", - "type": "array", - "items": {"$ref": "#/definitions/Document"}, - } - }, - "definitions": { - "Document": { - "title": "Document", - "description": "Class for storing a piece of text and associated metadata.", # noqa: E501 - "type": "object", - "properties": { - "page_content": {"title": "Page Content", "type": "string"}, - "metadata": {"title": "Metadata", "type": "object"}, - "type": { - "title": "Type", - "type": "string", - "enum": ["Document"], - "default": "Document", - }, - }, - "required": ["page_content"], - } - }, - 
} - assert stuff_chain.output_schema.schema() == { - "title": "CombineDocumentsOutput", - "type": "object", - "properties": {"output_text": {"title": "Output Text", "type": "string"}}, - } - - mapreduce_chain = load_summarize_chain( - model, "map_reduce", return_intermediate_steps=True - ) - - assert mapreduce_chain.input_schema.schema() == { - "title": "CombineDocumentsInput", - "type": "object", - "properties": { - "input_documents": { - "title": "Input Documents", - "type": "array", - "items": {"$ref": "#/definitions/Document"}, - } - }, - "definitions": { - "Document": { - "title": "Document", - "description": "Class for storing a piece of text and associated metadata.", # noqa: E501 - "type": "object", - "properties": { - "page_content": {"title": "Page Content", "type": "string"}, - "metadata": {"title": "Metadata", "type": "object"}, - "type": { - "title": "Type", - "type": "string", - "enum": ["Document"], - "default": "Document", - }, - }, - "required": ["page_content"], - } - }, - } - assert mapreduce_chain.output_schema.schema() == { - "title": "MapReduceDocumentsOutput", - "type": "object", - "properties": { - "output_text": {"title": "Output Text", "type": "string"}, - "intermediate_steps": { - "title": "Intermediate Steps", - "type": "array", - "items": {"type": "string"}, - }, - }, - } - - maprerank_chain = load_qa_chain(model, "map_rerank", metadata_keys=["hello"]) - - assert maprerank_chain.input_schema.schema() == { - "title": "CombineDocumentsInput", - "type": "object", - "properties": { - "input_documents": { - "title": "Input Documents", - "type": "array", - "items": {"$ref": "#/definitions/Document"}, - } - }, - "definitions": { - "Document": { - "title": "Document", - "description": "Class for storing a piece of text and associated metadata.", # noqa: E501 - "type": "object", - "properties": { - "page_content": {"title": "Page Content", "type": "string"}, - "metadata": {"title": "Metadata", "type": "object"}, - "type": { - "title": "Type", - "type": "string", - "enum": ["Document"], - "default": "Document", - }, - }, - "required": ["page_content"], - } - }, - } - assert maprerank_chain.output_schema.schema() == { - "title": "MapRerankOutput", - "type": "object", - "properties": { - "output_text": {"title": "Output Text", "type": "string"}, - "hello": {"title": "Hello"}, - }, - } - - def test_configurable_fields() -> None: fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]] @@ -1281,7 +1139,6 @@ def test_configurable_fields_example() -> None: ) -@pytest.mark.asyncio async def test_passthrough_tap_async(mocker: MockerFixture) -> None: fake = FakeRunnable() mock = mocker.Mock() @@ -1309,7 +1166,6 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: mock.reset_mock() -@pytest.mark.asyncio async def test_with_config(mocker: MockerFixture) -> None: fake = FakeRunnable() spy = mocker.spy(fake, "invoke") @@ -1427,7 +1283,6 @@ async def test_with_config(mocker: MockerFixture) -> None: ] -@pytest.mark.asyncio async def test_default_method_implementations(mocker: MockerFixture) -> None: fake = FakeRunnable() spy = mocker.spy(fake, "invoke") @@ -1506,7 +1361,6 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: ] -@pytest.mark.asyncio async def test_prompt() -> None: prompt = ChatPromptTemplate.from_messages( messages=[ @@ -1711,7 +1565,6 @@ def test_with_listeners(mocker: MockerFixture) -> None: assert mock_end.call_count == 1 -@pytest.mark.asyncio async def test_with_listeners_async(mocker: MockerFixture) -> None: 
prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") @@ -1849,7 +1702,6 @@ def test_prompt_with_chat_model( ) -@pytest.mark.asyncio @freeze_time("2023-01-01") async def test_prompt_with_chat_model_async( mocker: MockerFixture, snapshot: SnapshotAssertion @@ -1957,7 +1809,6 @@ async def test_prompt_with_chat_model_async( ) -@pytest.mark.asyncio @freeze_time("2023-01-01") async def test_prompt_with_llm( mocker: MockerFixture, snapshot: SnapshotAssertion @@ -2150,7 +2001,6 @@ async def test_prompt_with_llm( ] -@pytest.mark.asyncio @freeze_time("2023-01-01") async def test_stream_log_retriever() -> None: prompt = ( @@ -2180,293 +2030,6 @@ async def test_stream_log_retriever() -> None: ): del op["value"]["id"] - assert stream_log[:-9] in [ - [ - RunLogPatch( - { - "op": "replace", - "path": "", - "value": { - "logs": {}, - "final_output": None, - "streamed_output": [], - }, - } - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/RunnableParallel", - "value": { - "end_time": None, - "final_output": None, - "metadata": {}, - "name": "RunnableParallel", - "start_time": "2023-01-01T00:00:00.000", - "streamed_output_str": [], - "tags": ["seq:step:1"], - "type": "chain", - }, - } - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/RunnableLambda", - "value": { - "end_time": None, - "final_output": None, - "metadata": {}, - "name": "RunnableLambda", - "start_time": "2023-01-01T00:00:00.000", - "streamed_output_str": [], - "tags": ["map:key:question"], - "type": "chain", - }, - } - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/RunnableLambda/final_output", - "value": {"output": "What is your name?"}, - }, - { - "op": "add", - "path": "/logs/RunnableLambda/end_time", - "value": "2023-01-01T00:00:00.000", - }, - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/Retriever", - "value": { - "end_time": None, - "final_output": None, - "metadata": {}, - "name": "Retriever", - "start_time": "2023-01-01T00:00:00.000", - "streamed_output_str": [], - "tags": ["map:key:documents"], - "type": "retriever", - }, - } - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/Retriever/final_output", - "value": { - "documents": [ - Document(page_content="foo"), - Document(page_content="bar"), - ] - }, - }, - { - "op": "add", - "path": "/logs/Retriever/end_time", - "value": "2023-01-01T00:00:00.000", - }, - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/RunnableParallel/final_output", - "value": { - "documents": [ - Document(page_content="foo"), - Document(page_content="bar"), - ], - "question": "What is your name?", - }, - }, - { - "op": "add", - "path": "/logs/RunnableParallel/end_time", - "value": "2023-01-01T00:00:00.000", - }, - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/ChatPromptTemplate", - "value": { - "end_time": None, - "final_output": None, - "metadata": {}, - "name": "ChatPromptTemplate", - "start_time": "2023-01-01T00:00:00.000", - "streamed_output_str": [], - "tags": ["seq:step:2"], - "type": "prompt", - }, - } - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/ChatPromptTemplate/final_output", - "value": ChatPromptValue( - messages=[ - SystemMessage(content="You are a nice assistant."), - HumanMessage( - content="[Document(page_content='foo'), Document(page_content='bar')]" # noqa: E501 - ), - HumanMessage(content="What is your name?"), - ] - ), - }, - { - "op": "add", - "path": "/logs/ChatPromptTemplate/end_time", - "value": "2023-01-01T00:00:00.000", - }, - ), - ], - [ - RunLogPatch( - { - "op": "replace", - "path": "", - "value": 
{"final_output": None, "logs": {}, "streamed_output": []}, - } - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/RunnableParallel", - "value": { - "end_time": None, - "final_output": None, - "metadata": {}, - "name": "RunnableParallel", - "start_time": "2023-01-01T00:00:00.000", - "streamed_output_str": [], - "tags": ["seq:step:1"], - "type": "chain", - }, - } - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/Retriever", - "value": { - "end_time": None, - "final_output": None, - "metadata": {}, - "name": "Retriever", - "start_time": "2023-01-01T00:00:00.000", - "streamed_output_str": [], - "tags": ["map:key:documents"], - "type": "retriever", - }, - } - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/RunnableLambda", - "value": { - "end_time": None, - "final_output": None, - "metadata": {}, - "name": "RunnableLambda", - "start_time": "2023-01-01T00:00:00.000", - "streamed_output_str": [], - "tags": ["map:key:question"], - "type": "chain", - }, - } - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/RunnableLambda/final_output", - "value": {"output": "What is your name?"}, - }, - { - "op": "add", - "path": "/logs/RunnableLambda/end_time", - "value": "2023-01-01T00:00:00.000", - }, - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/Retriever/final_output", - "value": { - "documents": [ - Document(page_content="foo"), - Document(page_content="bar"), - ] - }, - }, - { - "op": "add", - "path": "/logs/Retriever/end_time", - "value": "2023-01-01T00:00:00.000", - }, - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/RunnableParallel/final_output", - "value": { - "documents": [ - Document(page_content="foo"), - Document(page_content="bar"), - ], - "question": "What is your name?", - }, - }, - { - "op": "add", - "path": "/logs/RunnableParallel/end_time", - "value": "2023-01-01T00:00:00.000", - }, - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/ChatPromptTemplate", - "value": { - "end_time": None, - "final_output": None, - "metadata": {}, - "name": "ChatPromptTemplate", - "start_time": "2023-01-01T00:00:00.000", - "streamed_output_str": [], - "tags": ["seq:step:2"], - "type": "prompt", - }, - } - ), - RunLogPatch( - { - "op": "add", - "path": "/logs/ChatPromptTemplate/final_output", - "value": ChatPromptValue( - messages=[ - SystemMessage(content="You are a nice assistant."), - HumanMessage( - content="[Document(page_content='foo'), Document(page_content='bar')]" # noqa: E501 - ), - HumanMessage(content="What is your name?"), - ] - ), - }, - { - "op": "add", - "path": "/logs/ChatPromptTemplate/end_time", - "value": "2023-01-01T00:00:00.000", - }, - ), - ], - ] - assert sorted(cast(RunLog, add(stream_log)).state["logs"]) == [ "ChatPromptTemplate", "FakeListLLM", @@ -2478,7 +2041,6 @@ async def test_stream_log_retriever() -> None: ] -@pytest.mark.asyncio @freeze_time("2023-01-01") async def test_prompt_with_llm_and_async_lambda( mocker: MockerFixture, snapshot: SnapshotAssertion @@ -2759,7 +2321,6 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> assert len(map_run.child_runs) == 2 -@pytest.mark.asyncio @freeze_time("2023-01-01") async def test_router_runnable( mocker: MockerFixture, snapshot: SnapshotAssertion @@ -2812,7 +2373,6 @@ async def test_router_runnable( assert len(router_run.child_runs) == 2 -@pytest.mark.asyncio @freeze_time("2023-01-01") async def test_higher_order_lambda_runnable( mocker: MockerFixture, snapshot: SnapshotAssertion @@ -3071,7 +2631,6 @@ def test_map_stream_iterator_input() -> None: assert 
final_value.get("passthrough") == "i'm a textbot" -@pytest.mark.asyncio async def test_map_astream() -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") @@ -3194,7 +2753,6 @@ async def test_map_astream() -> None: ] -@pytest.mark.asyncio async def test_map_astream_iterator_input() -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") @@ -3411,7 +2969,6 @@ def test_deep_stream_assign() -> None: } -@pytest.mark.asyncio async def test_deep_astream() -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") @@ -3438,7 +2995,6 @@ async def test_deep_astream() -> None: assert "".join(chunks) == "foo-lish" -@pytest.mark.asyncio async def test_deep_astream_assign() -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") @@ -3552,7 +3108,6 @@ def test_runnable_sequence_transform() -> None: assert "".join(chunks) == "foo-lish" -@pytest.mark.asyncio async def test_runnable_sequence_atransform() -> None: llm = FakeStreamingListLLM(responses=["foo-lish"]) @@ -3600,7 +3155,6 @@ def llm_chain_with_fallbacks() -> Runnable: "runnable", ["llm_with_fallbacks", "llm_with_multi_fallbacks", "llm_chain_with_fallbacks"], ) -@pytest.mark.asyncio async def test_llm_with_fallbacks( runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion ) -> None: @@ -3743,7 +3297,6 @@ def test_retrying(mocker: MockerFixture) -> None: _lambda_mock.reset_mock() -@pytest.mark.asyncio async def test_async_retrying(mocker: MockerFixture) -> None: def _lambda(x: int) -> Union[int, Runnable]: if x == 1: @@ -3928,7 +3481,6 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None: assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None] -@pytest.mark.asyncio @freeze_time("2023-01-01") async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None: class ControlledExceptionRunnable(Runnable[str, str]): @@ -4158,7 +3710,6 @@ def test_runnable_branch_batch() -> None: assert branch.batch([1, 10, 0]) == [2, 100, -1] -@pytest.mark.asyncio async def test_runnable_branch_ainvoke() -> None: """Test async variant of invoke.""" branch = RunnableBranch[int, int]( @@ -4214,7 +3765,6 @@ def test_runnable_branch_invoke_callbacks() -> None: assert tracer.runs[1].outputs is None -@pytest.mark.asyncio async def test_runnable_branch_ainvoke_callbacks() -> None: """Verify that callbacks are invoked correctly in ainvoke.""" tracer = FakeTracer() @@ -4242,7 +3792,6 @@ async def test_runnable_branch_ainvoke_callbacks() -> None: assert tracer.runs[1].outputs is None -@pytest.mark.asyncio async def test_runnable_branch_abatch() -> None: """Test async variant of invoke.""" branch = RunnableBranch[int, int]( @@ -4289,7 +3838,6 @@ def test_representation_of_runnables() -> None: ), "repr where code string contains multiple lambdas gives up" -@pytest.mark.asyncio async def test_tool_from_runnable() -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") @@ -4318,7 +3866,6 @@ async def test_tool_from_runnable() -> None: } -@pytest.mark.asyncio async def test_runnable_gen() -> None: """Test that a generator can be used as a runnable.""" @@ -4351,7 +3898,6 @@ async def test_runnable_gen() -> None: assert await arunnable.abatch([None, None]) == [6, 6] -@pytest.mark.asyncio async def test_runnable_gen_transform() -> None: """Test that a generator can be used as a runnable.""" diff --git 
a/libs/langchain/tests/unit_tests/schema/runnable/test_utils.py b/libs/core/tests/unit_tests/runnable/test_utils.py
similarity index 95%
rename from libs/langchain/tests/unit_tests/schema/runnable/test_utils.py
rename to libs/core/tests/unit_tests/runnable/test_utils.py
index 30f9e23bb26..1bbf5a8a91a 100644
--- a/libs/langchain/tests/unit_tests/schema/runnable/test_utils.py
+++ b/libs/core/tests/unit_tests/runnable/test_utils.py
@@ -3,7 +3,7 @@ from typing import Callable
 
 import pytest
 
-from langchain.schema.runnable.utils import (
+from langchain_core.runnables.utils import (
     get_lambda_source,
     indent_lines_after_first,
 )
diff --git a/libs/core/tests/unit_tests/schema/__init__.py b/libs/core/tests/unit_tests/schema/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/libs/core/tests/unit_tests/schema/test_imports.py b/libs/core/tests/unit_tests/schema/test_imports.py
new file mode 100644
index 00000000000..5bc2f228798
--- /dev/null
+++ b/libs/core/tests/unit_tests/schema/test_imports.py
@@ -0,0 +1,43 @@
+from langchain_core.schema import __all__
+
+EXPECTED_ALL = [
+    "BaseCache",
+    "BaseMemory",
+    "BaseStore",
+    "AgentFinish",
+    "AgentAction",
+    "Document",
+    "BaseChatMessageHistory",
+    "BaseDocumentTransformer",
+    "BaseMessage",
+    "ChatMessage",
+    "FunctionMessage",
+    "HumanMessage",
+    "AIMessage",
+    "SystemMessage",
+    "messages_from_dict",
+    "messages_to_dict",
+    "_message_to_dict",
+    "_message_from_dict",
+    "get_buffer_string",
+    "RunInfo",
+    "LLMResult",
+    "ChatResult",
+    "ChatGeneration",
+    "Generation",
+    "PromptValue",
+    "LangChainException",
+    "BaseRetriever",
+    "RUN_KEY",
+    "Memory",
+    "OutputParserException",
+    "StrOutputParser",
+    "BaseOutputParser",
+    "BaseLLMOutputParser",
+    "BasePromptTemplate",
+    "format_document",
+]
+
+
+def test_all_imports() -> None:
+    assert set(__all__) == set(EXPECTED_ALL)
diff --git a/libs/core/tests/unit_tests/schema/test_messages.py b/libs/core/tests/unit_tests/schema/test_messages.py
new file mode 100644
index 00000000000..8c263e6ed91
--- /dev/null
+++ b/libs/core/tests/unit_tests/schema/test_messages.py
@@ -0,0 +1,102 @@
+import pytest
+
+from langchain_core.schema.messages import (
+    AIMessageChunk,
+    ChatMessageChunk,
+    FunctionMessageChunk,
+    HumanMessageChunk,
+)
+
+
+def test_message_chunks() -> None:
+    assert AIMessageChunk(content="I am") + AIMessageChunk(
+        content=" indeed."
+    ) == AIMessageChunk(
+        content="I am indeed."
+    ), "MessageChunk + MessageChunk should be a MessageChunk"
+
+    assert (
+        AIMessageChunk(content="I am") + HumanMessageChunk(content=" indeed.")
+        == AIMessageChunk(content="I am indeed.")
+    ), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side"  # noqa: E501
+
+    assert (
+        AIMessageChunk(content="", additional_kwargs={"foo": "bar"})
+        + AIMessageChunk(content="", additional_kwargs={"baz": "foo"})
+        == AIMessageChunk(content="", additional_kwargs={"foo": "bar", "baz": "foo"})
+    ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs"  # noqa: E501
+
+    assert (
+        AIMessageChunk(
+            content="", additional_kwargs={"function_call": {"name": "web_search"}}
+        )
+        + AIMessageChunk(
+            content="", additional_kwargs={"function_call": {"arguments": "{\n"}}
+        )
+        + AIMessageChunk(
+            content="",
+            additional_kwargs={
+                "function_call": {"arguments": ' "query": "turtles"\n}'}
+            },
+        )
+        == AIMessageChunk(
+            content="",
+            additional_kwargs={
+                "function_call": {
+                    "name": "web_search",
+                    "arguments": '{\n "query": "turtles"\n}',
+                }
+            },
+        )
+    ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs"  # noqa: E501
+
+
+def test_chat_message_chunks() -> None:
+    assert ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
+        role="User", content=" indeed."
+    ) == ChatMessageChunk(
+        role="User", content="I am indeed."
+    ), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
+
+    with pytest.raises(ValueError):
+        ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
+            role="Assistant", content=" indeed."
+        )
+
+    assert (
+        ChatMessageChunk(role="User", content="I am")
+        + AIMessageChunk(content=" indeed.")
+        == ChatMessageChunk(role="User", content="I am indeed.")
+    ), "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk with the left side's role"  # noqa: E501
+
+    assert AIMessageChunk(content="I am") + ChatMessageChunk(
+        role="User", content=" indeed."
+    ) == AIMessageChunk(
+        content="I am indeed."
+    ), "Other MessageChunk + ChatMessageChunk should be a MessageChunk as the left side"  # noqa: E501
+
+
+def test_function_message_chunks() -> None:
+    assert FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
+        name="hello", content=" indeed."
+    ) == FunctionMessageChunk(
+        name="hello", content="I am indeed."
+    ), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
+
+    with pytest.raises(ValueError):
+        FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
+            name="bye", content=" indeed."
+        )
+
+
+def test_ani_message_chunks() -> None:
+    assert AIMessageChunk(example=True, content="I am") + AIMessageChunk(
+        example=True, content=" indeed."
+    ) == AIMessageChunk(
+        example=True, content="I am indeed."
+    ), "AIMessageChunk + AIMessageChunk should be a AIMessageChunk"
+
+    with pytest.raises(ValueError):
+        AIMessageChunk(example=True, content="I am") + AIMessageChunk(
+            example=False, content=" indeed."
+        )
diff --git a/libs/core/tests/unit_tests/schema/test_output.py b/libs/core/tests/unit_tests/schema/test_output.py
new file mode 100644
index 00000000000..5e086c5e5a3
--- /dev/null
+++ b/libs/core/tests/unit_tests/schema/test_output.py
@@ -0,0 +1,60 @@
+from langchain_core.schema.messages import HumanMessageChunk
+from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk
+
+
+def test_generation_chunk() -> None:
+    assert GenerationChunk(text="Hello, ") + GenerationChunk(
+        text="world!"
+    ) == GenerationChunk(
+        text="Hello, world!"
+    ), "GenerationChunk + GenerationChunk should be a GenerationChunk"
+
+    assert (
+        GenerationChunk(text="Hello, ")
+        + GenerationChunk(text="world!", generation_info={"foo": "bar"})
+        == GenerationChunk(text="Hello, world!", generation_info={"foo": "bar"})
+    ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info"  # noqa: E501
+
+    assert (
+        GenerationChunk(text="Hello, ")
+        + GenerationChunk(text="world!", generation_info={"foo": "bar"})
+        + GenerationChunk(text="!", generation_info={"baz": "foo"})
+        == GenerationChunk(
+            text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"}
+        )
+    ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info"  # noqa: E501
+
+
+def test_chat_generation_chunk() -> None:
+    assert ChatGenerationChunk(
+        message=HumanMessageChunk(content="Hello, ")
+    ) + ChatGenerationChunk(
+        message=HumanMessageChunk(content="world!")
+    ) == ChatGenerationChunk(
+        message=HumanMessageChunk(content="Hello, world!")
+    ), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk"
+
+    assert (
+        ChatGenerationChunk(message=HumanMessageChunk(content="Hello, "))
+        + ChatGenerationChunk(
+            message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
+        )
+        == ChatGenerationChunk(
+            message=HumanMessageChunk(content="Hello, world!"),
+            generation_info={"foo": "bar"},
+        )
+    ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info"  # noqa: E501
+
+    assert (
+        ChatGenerationChunk(message=HumanMessageChunk(content="Hello, "))
+        + ChatGenerationChunk(
+            message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
+        )
+        + ChatGenerationChunk(
+            message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"}
+        )
+        == ChatGenerationChunk(
+            message=HumanMessageChunk(content="Hello, world!!"),
+            generation_info={"foo": "bar", "baz": "foo"},
+        )
+    ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info"  # noqa: E501
diff --git a/libs/core/tests/unit_tests/test_globals.py b/libs/core/tests/unit_tests/test_globals.py
new file mode 100644
index 00000000000..760b9d63c59
--- /dev/null
+++ b/libs/core/tests/unit_tests/test_globals.py
@@ -0,0 +1,31 @@
+from langchain_core.globals import get_debug, set_debug
+
+
+def test_debug_is_settable_via_setter() -> None:
+    from langchain_core import globals
+    from langchain_core.callbacks.manager import _get_debug
+
+    previous_value = globals._debug
+    previous_fn_reading = _get_debug()
+    assert previous_value == previous_fn_reading
+
+    # Flip the value of the flag.
+    set_debug(not previous_value)
+
+    new_value = globals._debug
+    new_fn_reading = _get_debug()
+
+    try:
+        # We successfully changed the value of `debug`.
+        assert new_value != previous_value
+
+        # If we access `debug` via a function used elsewhere in langchain,
+        # it also sees the same new value.
+ assert new_value == new_fn_reading + + # If we access `debug` via `get_debug()` we also get the same value. + assert new_value == get_debug() + finally: + # Make sure we don't alter global state, even if the test fails. + # Always reset `debug` to the value it had before. + set_debug(previous_value) diff --git a/libs/langchain/tests/unit_tests/tools/test_base.py b/libs/core/tests/unit_tests/test_tool.py similarity index 98% rename from libs/langchain/tests/unit_tests/tools/test_base.py rename to libs/core/tests/unit_tests/test_tool.py index 9c5fdf39e10..1f63798925e 100644 --- a/libs/langchain/tests/unit_tests/tools/test_base.py +++ b/libs/core/tests/unit_tests/test_tool.py @@ -7,19 +7,20 @@ from typing import Any, List, Optional, Type, Union import pytest -from langchain.agents.tools import Tool, tool -from langchain.callbacks.manager import ( +from langchain_core.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel -from langchain.tools.base import ( +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.tool import ( BaseTool, SchemaAnnotationError, StructuredTool, + Tool, ToolException, + tool, ) -from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +from tests.unit_tests.fake.callbacks import FakeCallbackHandler def test_unnamed_decorator() -> None: @@ -571,7 +572,6 @@ def test_create_tool_keyword_args() -> None: assert test_tool.description == "test_description" -@pytest.mark.asyncio async def test_create_async_tool() -> None: """Test that async tools are allowed.""" @@ -632,7 +632,6 @@ def test_exception_handling_non_tool_exception() -> None: _tool.run({}) -@pytest.mark.asyncio async def test_async_exception_handling_bool() -> None: _tool = _FakeExceptionTool(handle_tool_error=True) expected = "Tool execution error" @@ -640,7 +639,6 @@ async def test_async_exception_handling_bool() -> None: assert expected == actual -@pytest.mark.asyncio async def test_async_exception_handling_str() -> None: expected = "foo bar" _tool = _FakeExceptionTool(handle_tool_error=expected) @@ -648,7 +646,6 @@ async def test_async_exception_handling_str() -> None: assert expected == actual -@pytest.mark.asyncio async def test_async_exception_handling_callable() -> None: expected = "foo bar" handling = lambda _: expected # noqa: E731 @@ -657,7 +654,6 @@ async def test_async_exception_handling_callable() -> None: assert expected == actual -@pytest.mark.asyncio async def test_async_exception_handling_non_tool_exception() -> None: _tool = _FakeExceptionTool(exception=ValueError()) with pytest.raises(ValueError): diff --git a/libs/core/tests/unit_tests/utils/__init__.py b/libs/core/tests/unit_tests/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/core/tests/unit_tests/utils/test_imports.py b/libs/core/tests/unit_tests/utils/test_imports.py new file mode 100644 index 00000000000..91e5c016e7d --- /dev/null +++ b/libs/core/tests/unit_tests/utils/test_imports.py @@ -0,0 +1,21 @@ +from langchain_core.utils import __all__ + +EXPECTED_ALL = [ + "StrictFormatter", + "check_package_version", + "convert_to_secret_str", + "formatter", + "get_bolded_text", + "get_color_mapping", + "get_colored_text", + "get_pydantic_field_names", + "guard_import", + "mock_now", + "print_text", + "raise_for_status_with_text", + "xor_args", +] + + +def test_all_imports() -> None: + assert set(__all__) == set(EXPECTED_ALL) diff --git a/libs/core/tests/unit_tests/utils/test_iter.py 
b/libs/core/tests/unit_tests/utils/test_iter.py new file mode 100644 index 00000000000..d0866ea3fc0 --- /dev/null +++ b/libs/core/tests/unit_tests/utils/test_iter.py @@ -0,0 +1,21 @@ +from typing import List + +import pytest + +from langchain_core.utils.iter import batch_iterate + + +@pytest.mark.parametrize( + "input_size, input_iterable, expected_output", + [ + (2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]), + (3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]), + (1, [100, 200, 300], [[100], [200], [300]]), + (4, [], []), + ], +) +def test_batch_iterate( + input_size: int, input_iterable: List[str], expected_output: List[str] +) -> None: + """Test batching function.""" + assert list(batch_iterate(input_size, input_iterable)) == expected_output diff --git a/libs/langchain/Makefile b/libs/langchain/Makefile index 6a15cb7c7ab..d927f62d671 100644 --- a/libs/langchain/Makefile +++ b/libs/langchain/Makefile @@ -25,7 +25,10 @@ extended_tests: poetry run pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests test_watch: - poetry run ptw --now . -- tests/unit_tests + poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket tests/unit_tests + +test_watch_extended: + poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests integration_tests: poetry run pytest tests/integration_tests diff --git a/libs/langchain/langchain/__init__.py b/libs/langchain/langchain/__init__.py index b6c32407e4b..8f47b520aa5 100644 --- a/libs/langchain/langchain/__init__.py +++ b/libs/langchain/langchain/__init__.py @@ -4,7 +4,7 @@ import warnings from importlib import metadata from typing import Any, Optional -from langchain._api.deprecation import surface_langchain_deprecation_warnings +from langchain_core._api.deprecation import surface_langchain_deprecation_warnings try: __version__ = metadata.version(__package__) @@ -233,25 +233,25 @@ def __getattr__(name: str) -> Any: return HuggingFacePipeline elif name == "FewShotPromptTemplate": - from langchain.prompts import FewShotPromptTemplate + from langchain_core.prompts import FewShotPromptTemplate _warn_on_import(name, replacement="langchain.prompts.FewShotPromptTemplate") return FewShotPromptTemplate elif name == "Prompt": - from langchain.prompts import Prompt + from langchain_core.prompts import Prompt _warn_on_import(name, replacement="langchain.prompts.Prompt") return Prompt elif name == "PromptTemplate": - from langchain.prompts import PromptTemplate + from langchain_core.prompts import PromptTemplate _warn_on_import(name, replacement="langchain.prompts.PromptTemplate") return PromptTemplate elif name == "BasePromptTemplate": - from langchain.schema.prompt_template import BasePromptTemplate + from langchain_core.schema.prompt_template import BasePromptTemplate _warn_on_import( name, replacement="langchain.schema.prompt_template.BasePromptTemplate" diff --git a/libs/langchain/langchain/_api/deprecation.py b/libs/langchain/langchain/_api/deprecation.py index 6919504351d..e85ab4046ce 100644 --- a/libs/langchain/langchain/_api/deprecation.py +++ b/libs/langchain/langchain/_api/deprecation.py @@ -1,341 +1,17 @@ -"""Helper functions for deprecating parts of the LangChain API. 
+from langchain_core._api.deprecation import ( + LangChainDeprecationWarning, + LangChainPendingDeprecationWarning, + deprecated, + suppress_langchain_deprecation_warning, + surface_langchain_deprecation_warnings, + warn_deprecated, +) -This module was adapted from matplotlibs _api/deprecation.py module: - -https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecation.py - -.. warning:: - - This module is for internal use only. Do not use it in your own code. - We may change the API at any time with no warning. -""" - -import contextlib -import functools -import inspect -import warnings -from typing import Any, Callable, Generator, Type, TypeVar - - -class LangChainDeprecationWarning(DeprecationWarning): - """A class for issuing deprecation warnings for LangChain users.""" - - -class LangChainPendingDeprecationWarning(PendingDeprecationWarning): - """A class for issuing deprecation warnings for LangChain users.""" - - -# PUBLIC API - - -T = TypeVar("T", Type, Callable) - - -def deprecated( - since: str, - *, - message: str = "", - name: str = "", - alternative: str = "", - pending: bool = False, - obj_type: str = "", - addendum: str = "", - removal: str = "", -) -> Callable[[T], T]: - """Decorator to mark a function, a class, or a property as deprecated. - - When deprecating a classmethod, a staticmethod, or a property, the - ``@deprecated`` decorator should go *under* ``@classmethod`` and - ``@staticmethod`` (i.e., `deprecated` should directly decorate the - underlying callable), but *over* ``@property``. - - When deprecating a class ``C`` intended to be used as a base class in a - multiple inheritance hierarchy, ``C`` *must* define an ``__init__`` method - (if ``C`` instead inherited its ``__init__`` from its own base class, then - ``@deprecated`` would mess up ``__init__`` inheritance when installing its - own (deprecation-emitting) ``C.__init__``). - - Parameters are the same as for `warn_deprecated`, except that *obj_type* - defaults to 'class' if decorating a class, 'attribute' if decorating a - property, and 'function' otherwise. - - Arguments: - since : str - The release at which this API became deprecated. - message : str, optional - Override the default deprecation message. The %(since)s, - %(name)s, %(alternative)s, %(obj_type)s, %(addendum)s, - and %(removal)s format specifiers will be replaced by the - values of the respective arguments passed to this function. - name : str, optional - The name of the deprecated object. - alternative : str, optional - An alternative API that the user may use in place of the - deprecated API. The deprecation warning will tell the user - about this alternative if provided. - pending : bool, optional - If True, uses a PendingDeprecationWarning instead of a - DeprecationWarning. Cannot be used together with removal. - obj_type : str, optional - The object type being deprecated. - addendum : str, optional - Additional text appended directly to the final message. - removal : str, optional - The expected removal version. With the default (an empty - string), a removal version is automatically computed from - since. Set to other Falsy values to not schedule a removal - date. Cannot be used together with pending. - - Examples - -------- - - .. 
code-block:: python - - @deprecated('1.4.0') - def the_function_to_deprecate(): - pass - """ - - def deprecate( - obj: T, - *, - _obj_type: str = obj_type, - _name: str = name, - _message: str = message, - _alternative: str = alternative, - _pending: bool = pending, - _addendum: str = addendum, - ) -> T: - """Implementation of the decorator returned by `deprecated`.""" - if isinstance(obj, type): - if not _obj_type: - _obj_type = "class" - wrapped = obj.__init__ # type: ignore - _name = _name or obj.__name__ - old_doc = obj.__doc__ - - def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: - """Finalize the deprecation of a class.""" - try: - obj.__doc__ = new_doc - except AttributeError: # Can't set on some extension objects. - pass - obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc] - wrapper - ) - return obj - - elif isinstance(obj, property): - if not _obj_type: - _obj_type = "attribute" - wrapped = None - _name = _name or obj.fget.__name__ - old_doc = obj.__doc__ - - class _deprecated_property(type(obj)): # type: ignore - """A deprecated property.""" - - def __get__(self, instance, owner=None): # type: ignore - if instance is not None or owner is not None: - emit_warning() - return super().__get__(instance, owner) - - def __set__(self, instance, value): # type: ignore - if instance is not None: - emit_warning() - return super().__set__(instance, value) - - def __delete__(self, instance): # type: ignore - if instance is not None: - emit_warning() - return super().__delete__(instance) - - def __set_name__(self, owner, set_name): # type: ignore - nonlocal _name - if _name == "": - _name = set_name - - def finalize(_: Any, new_doc: str) -> Any: # type: ignore - """Finalize the property.""" - return _deprecated_property( - fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc - ) - - else: - if not _obj_type: - _obj_type = "function" - wrapped = obj - _name = _name or obj.__name__ # type: ignore - old_doc = wrapped.__doc__ - - def finalize( # type: ignore - wrapper: Callable[..., Any], new_doc: str - ) -> T: - """Wrap the wrapped function using the wrapper and update the docstring. - - Args: - wrapper: The wrapper function. - new_doc: The new docstring. - - Returns: - The wrapped function. - """ - wrapper = functools.wraps(wrapped)(wrapper) - wrapper.__doc__ = new_doc - return wrapper - - def emit_warning() -> None: - """Emit the warning.""" - warn_deprecated( - since, - message=_message, - name=_name, - alternative=_alternative, - pending=_pending, - obj_type=_obj_type, - addendum=_addendum, - removal=removal, - ) - - def warning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any: - """Wrapper for the original wrapped callable that emits a warning. - - Args: - *args: The positional arguments to the function. - **kwargs: The keyword arguments to the function. - - Returns: - The return value of the function being wrapped. - """ - emit_warning() - return wrapped(*args, **kwargs) - - old_doc = inspect.cleandoc(old_doc or "").strip("\n") - - if not old_doc: - new_doc = "[*Deprecated*]" - else: - new_doc = f"[*Deprecated*] {old_doc}" - - # Modify the docstring to include a deprecation notice. - notes_header = "\nNotes\n-----" - components = [ - message, - f"Use {alternative} instead." if alternative else "", - addendum, - ] - details = " ".join([component.strip() for component in components if component]) - new_doc += ( - f"[*Deprecated*] {old_doc}\n" - f"{notes_header if notes_header not in old_doc else ''}\n" - f".. 
deprecated:: {since}\n" - f" {details}" - ) - - return finalize(warning_emitting_wrapper, new_doc) - - return deprecate - - -@contextlib.contextmanager -def suppress_langchain_deprecation_warning() -> Generator[None, None, None]: - """Context manager to suppress LangChainDeprecationWarning.""" - with warnings.catch_warnings(): - warnings.simplefilter("ignore", LangChainDeprecationWarning) - warnings.simplefilter("ignore", LangChainPendingDeprecationWarning) - yield - - -def warn_deprecated( - since: str, - *, - message: str = "", - name: str = "", - alternative: str = "", - pending: bool = False, - obj_type: str = "", - addendum: str = "", - removal: str = "", -) -> None: - """Display a standardized deprecation. - - Arguments: - since : str - The release at which this API became deprecated. - message : str, optional - Override the default deprecation message. The %(since)s, - %(name)s, %(alternative)s, %(obj_type)s, %(addendum)s, - and %(removal)s format specifiers will be replaced by the - values of the respective arguments passed to this function. - name : str, optional - The name of the deprecated object. - alternative : str, optional - An alternative API that the user may use in place of the - deprecated API. The deprecation warning will tell the user - about this alternative if provided. - pending : bool, optional - If True, uses a PendingDeprecationWarning instead of a - DeprecationWarning. Cannot be used together with removal. - obj_type : str, optional - The object type being deprecated. - addendum : str, optional - Additional text appended directly to the final message. - removal : str, optional - The expected removal version. With the default (an empty - string), a removal version is automatically computed from - since. Set to other Falsy values to not schedule a removal - date. Cannot be used together with pending. - """ - if pending and removal: - raise ValueError("A pending deprecation cannot have a scheduled removal") - - if not pending: - if not removal: - removal = f"in {removal}" if removal else "within ?? minor releases" - raise NotImplementedError( - f"Need to determine which default deprecation schedule to use. " - f"{removal}" - ) - else: - removal = f"in {removal}" - - if not message: - message = "" - - if obj_type: - message += f"The {obj_type} `{name}`" - else: - message += f"`{name}`" - - if pending: - message += " will be deprecated in a future version" - else: - message += f" was deprecated in LangChain {since}" - - if removal: - message += f" and will be removed {removal}" - - if alternative: - message += f". Use {alternative} instead." 
- - if addendum: - message += f" {addendum}" - - warning_cls = ( - LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning - ) - warning = warning_cls(message) - warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2) - - -def surface_langchain_deprecation_warnings() -> None: - """Unmute LangChain deprecation warnings.""" - warnings.filterwarnings( - "default", - category=LangChainPendingDeprecationWarning, - ) - - warnings.filterwarnings( - "default", - category=LangChainDeprecationWarning, - ) +__all__ = [ + "LangChainDeprecationWarning", + "LangChainPendingDeprecationWarning", + "deprecated", + "suppress_langchain_deprecation_warning", + "warn_deprecated", + "surface_langchain_deprecation_warnings", +] diff --git a/libs/langchain/langchain/_api/path.py b/libs/langchain/langchain/_api/path.py index 0589ae44956..5ee0fe817da 100644 --- a/libs/langchain/langchain/_api/path.py +++ b/libs/langchain/langchain/_api/path.py @@ -1,36 +1,3 @@ -import os -from pathlib import Path -from typing import Optional, Union +from langchain_core._api.path import as_import_path, get_relative_path -HERE = Path(__file__).parent - -# Get directory of langchain package -PACKAGE_DIR = HERE.parent -SEPARATOR = os.sep - - -def get_relative_path( - file: Union[Path, str], *, relative_to: Path = PACKAGE_DIR -) -> str: - """Get the path of the file as a relative path to the package directory.""" - if isinstance(file, str): - file = Path(file) - return str(file.relative_to(relative_to)) - - -def as_import_path( - file: Union[Path, str], - *, - suffix: Optional[str] = None, - relative_to: Path = PACKAGE_DIR, -) -> str: - """Path of the file as a LangChain import exclude langchain top namespace.""" - if isinstance(file, str): - file = Path(file) - path = get_relative_path(file, relative_to=relative_to) - if file.is_file(): - path = path[: -len(file.suffix)] - import_path = path.replace(SEPARATOR, ".") - if suffix: - import_path += "." + suffix - return import_path +__all__ = ["get_relative_path", "as_import_path"] diff --git a/libs/langchain/langchain/adapters/openai.py b/libs/langchain/langchain/adapters/openai.py index e35dcf3e791..dad0d4e419d 100644 --- a/libs/langchain/langchain/adapters/openai.py +++ b/libs/langchain/langchain/adapters/openai.py @@ -13,10 +13,8 @@ from typing import ( overload, ) -from typing_extensions import Literal - -from langchain.schema.chat import ChatSession -from langchain.schema.messages import ( +from langchain_core.schema.chat import ChatSession +from langchain_core.schema.messages import ( AIMessage, AIMessageChunk, BaseMessage, @@ -27,6 +25,7 @@ from langchain.schema.messages import ( SystemMessage, ToolMessage, ) +from typing_extensions import Literal async def aenumerate( diff --git a/libs/langchain/langchain/agents/__init__.py b/libs/langchain/langchain/agents/__init__.py index b1329af30f6..5af4b9ed7c7 100644 --- a/libs/langchain/langchain/agents/__init__.py +++ b/libs/langchain/langchain/agents/__init__.py @@ -31,7 +31,8 @@ Agents select and use **Tools** and **Toolkits** for actions. 
from pathlib import Path from typing import Any -from langchain._api.path import as_import_path +from langchain_core._api.path import as_import_path + from langchain.agents.agent import ( Agent, AgentExecutor, diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 48de6b1ee2f..905826ccf2f 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -19,6 +19,20 @@ from typing import ( ) import yaml +from langchain_core.prompts.few_shot import FewShotPromptTemplate +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.runnables import Runnable +from langchain_core.schema import ( + AgentAction, + AgentFinish, + BaseOutputParser, + BasePromptTemplate, + OutputParserException, +) +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import BaseMessage +from langchain_core.utils.input import get_color_mapping from langchain.agents.agent_iterator import AgentExecutorIterator from langchain.agents.agent_types import AgentType @@ -33,22 +47,8 @@ from langchain.callbacks.manager import ( ) from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.prompts.few_shot import FewShotPromptTemplate -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema import ( - AgentAction, - AgentFinish, - BaseOutputParser, - BasePromptTemplate, - OutputParserException, -) -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import BaseMessage -from langchain.schema.runnable import Runnable from langchain.tools.base import BaseTool from langchain.utilities.asyncio import asyncio_timeout -from langchain.utils.input import get_color_mapping logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/agents/agent_iterator.py b/libs/langchain/langchain/agents/agent_iterator.py index b7b706fea34..f0da9972a31 100644 --- a/libs/langchain/langchain/agents/agent_iterator.py +++ b/libs/langchain/langchain/agents/agent_iterator.py @@ -18,6 +18,10 @@ from typing import ( Union, ) +from langchain_core.load.dump import dumpd +from langchain_core.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo +from langchain_core.utils.input import get_color_mapping + from langchain.callbacks.manager import ( AsyncCallbackManager, AsyncCallbackManagerForChainRun, @@ -25,11 +29,8 @@ from langchain.callbacks.manager import ( CallbackManagerForChainRun, Callbacks, ) -from langchain.load.dump import dumpd -from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo from langchain.tools import BaseTool from langchain.utilities.asyncio import asyncio_timeout -from langchain.utils.input import get_color_mapping if TYPE_CHECKING: from langchain.agents.agent import AgentExecutor diff --git a/libs/langchain/langchain/agents/agent_toolkits/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/__init__.py index 9008e582554..324b1e2942b 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/__init__.py +++ b/libs/langchain/langchain/agents/agent_toolkits/__init__.py @@ -16,7 +16,8 @@ See [Security](https://python.langchain.com/docs/security) for more information. 
from pathlib import Path from typing import Any -from langchain._api.path import as_import_path +from langchain_core._api.path import as_import_path + from langchain.agents.agent_toolkits.ainetwork.toolkit import AINetworkToolkit from langchain.agents.agent_toolkits.amadeus.toolkit import AmadeusToolkit from langchain.agents.agent_toolkits.azure_cognitive_services import ( diff --git a/libs/langchain/langchain/agents/agent_toolkits/ainetwork/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/ainetwork/toolkit.py index ba451c106ff..9cedd44aa1d 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/ainetwork/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/ainetwork/toolkit.py @@ -2,8 +2,9 @@ from __future__ import annotations from typing import TYPE_CHECKING, List, Literal, Optional +from langchain_core.pydantic_v1 import root_validator + from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.pydantic_v1 import root_validator from langchain.tools import BaseTool from langchain.tools.ainetwork.app import AINAppOps from langchain.tools.ainetwork.owner import AINOwnerOps diff --git a/libs/langchain/langchain/agents/agent_toolkits/amadeus/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/amadeus/toolkit.py index 27e5c778e70..c1dd29925b2 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/amadeus/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/amadeus/toolkit.py @@ -2,8 +2,9 @@ from __future__ import annotations from typing import TYPE_CHECKING, List +from langchain_core.pydantic_v1 import Field + from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.pydantic_v1 import Field from langchain.tools import BaseTool from langchain.tools.amadeus.closest_airport import AmadeusClosestAirport from langchain.tools.amadeus.flight_search import AmadeusFlightSearch diff --git a/libs/langchain/langchain/agents/agent_toolkits/base.py b/libs/langchain/langchain/agents/agent_toolkits/base.py index 5eeccce6b02..16969eb9810 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/base.py @@ -2,7 +2,8 @@ from abc import ABC, abstractmethod from typing import List -from langchain.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel + from langchain.tools import BaseTool diff --git a/libs/langchain/langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py b/libs/langchain/langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py index fb57fba0b2a..f68a0a05817 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py +++ b/libs/langchain/langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py @@ -1,5 +1,10 @@ from typing import Any, List, Optional +from langchain_core.prompts.chat import MessagesPlaceholder +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.memory import BaseMemory +from langchain_core.schema.messages import SystemMessage + from langchain.agents.agent import AgentExecutor from langchain.agents.openai_functions_agent.agent_token_buffer_memory import ( AgentTokenBufferMemory, @@ -7,10 +12,6 @@ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import ( from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent from langchain.chat_models.openai import ChatOpenAI from langchain.memory.token_buffer import 
ConversationTokenBufferMemory -from langchain.prompts.chat import MessagesPlaceholder -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.memory import BaseMemory -from langchain.schema.messages import SystemMessage from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/agents/agent_toolkits/csv/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/csv/__init__.py index 7af077a8e98..1b4899a2d57 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/csv/__init__.py +++ b/libs/langchain/langchain/agents/agent_toolkits/csv/__init__.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Any -from langchain._api.path import as_import_path +from langchain_core._api.path import as_import_path def __getattr__(name: str) -> Any: diff --git a/libs/langchain/langchain/agents/agent_toolkits/file_management/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/file_management/toolkit.py index 6e4c59b49ec..551e455a0fa 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/file_management/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/file_management/toolkit.py @@ -2,8 +2,9 @@ from __future__ import annotations from typing import List, Optional +from langchain_core.pydantic_v1 import root_validator + from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.pydantic_v1 import root_validator from langchain.tools import BaseTool from langchain.tools.file_management.copy import CopyFileTool from langchain.tools.file_management.delete import DeleteFileTool diff --git a/libs/langchain/langchain/agents/agent_toolkits/gmail/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/gmail/toolkit.py index 34120cc284a..9dfe8e84352 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/gmail/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/gmail/toolkit.py @@ -2,8 +2,9 @@ from __future__ import annotations from typing import TYPE_CHECKING, List +from langchain_core.pydantic_v1 import Field + from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.pydantic_v1 import Field from langchain.tools import BaseTool from langchain.tools.gmail.create_draft import GmailCreateDraft from langchain.tools.gmail.get_message import GmailGetMessage diff --git a/libs/langchain/langchain/agents/agent_toolkits/json/base.py b/libs/langchain/langchain/agents/agent_toolkits/json/base.py index b2836523a72..4d1d7c1cee2 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/json/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/json/base.py @@ -1,6 +1,8 @@ """Json agent.""" from typing import Any, Dict, List, Optional +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit @@ -8,7 +10,6 @@ from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.schema.language_model import BaseLanguageModel def create_json_agent( diff --git a/libs/langchain/langchain/agents/agent_toolkits/nla/tool.py b/libs/langchain/langchain/agents/agent_toolkits/nla/tool.py index 99b3c5a7082..232bb673a0b 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/nla/tool.py +++ 
b/libs/langchain/langchain/agents/agent_toolkits/nla/tool.py @@ -3,9 +3,10 @@ from typing import Any, Optional +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.tools import Tool from langchain.chains.api.openapi.chain import OpenAPIEndpointChain -from langchain.schema.language_model import BaseLanguageModel from langchain.tools.openapi.utils.api_models import APIOperation from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec from langchain.utilities.requests import Requests diff --git a/libs/langchain/langchain/agents/agent_toolkits/nla/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/nla/toolkit.py index b0134ce3ab3..9e868e84c0c 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/nla/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/nla/toolkit.py @@ -2,10 +2,11 @@ from __future__ import annotations from typing import Any, List, Optional, Sequence +from langchain_core.pydantic_v1 import Field +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.agents.agent_toolkits.nla.tool import NLATool -from langchain.pydantic_v1 import Field -from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec from langchain.tools.plugin import AIPlugin diff --git a/libs/langchain/langchain/agents/agent_toolkits/office365/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/office365/toolkit.py index 48bff436775..d40423fd972 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/office365/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/office365/toolkit.py @@ -2,8 +2,9 @@ from __future__ import annotations from typing import TYPE_CHECKING, List +from langchain_core.pydantic_v1 import Field + from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.pydantic_v1 import Field from langchain.tools import BaseTool from langchain.tools.office365.create_draft_message import O365CreateDraftMessage from langchain.tools.office365.events_search import O365SearchEvents diff --git a/libs/langchain/langchain/agents/agent_toolkits/openapi/base.py b/libs/langchain/langchain/agents/agent_toolkits/openapi/base.py index 33014a56d7a..cb86dcb84d2 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/openapi/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/openapi/base.py @@ -1,6 +1,8 @@ """OpenAPI spec agent.""" from typing import Any, Dict, List, Optional +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.openapi.prompt import ( OPENAPI_PREFIX, @@ -11,7 +13,6 @@ from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.schema.language_model import BaseLanguageModel def create_openapi_agent( diff --git a/libs/langchain/langchain/agents/agent_toolkits/openapi/planner.py b/libs/langchain/langchain/agents/agent_toolkits/openapi/planner.py index 3dcaadedd66..4c20c496f6f 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/openapi/planner.py +++ b/libs/langchain/langchain/agents/agent_toolkits/openapi/planner.py @@ -5,6 +5,10 @@ from functools import partial from typing import Any, Callable, Dict, List, 
Optional import yaml +from langchain_core.prompts import PromptTemplate +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.openapi.planner_prompt import ( @@ -33,10 +37,6 @@ from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain from langchain.llms.openai import OpenAI from langchain.memory import ReadOnlySharedMemory -from langchain.prompts import PromptTemplate -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool from langchain.tools.requests.tool import BaseRequestsTool from langchain.utilities.requests import RequestsWrapper diff --git a/libs/langchain/langchain/agents/agent_toolkits/openapi/planner_prompt.py b/libs/langchain/langchain/agents/agent_toolkits/openapi/planner_prompt.py index c7263e68bee..ec99e823e06 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/openapi/planner_prompt.py +++ b/libs/langchain/langchain/agents/agent_toolkits/openapi/planner_prompt.py @@ -1,6 +1,6 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate API_PLANNER_PROMPT = """You are a planner that plans a sequence of API calls to assist with user queries against an API. diff --git a/libs/langchain/langchain/agents/agent_toolkits/openapi/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/openapi/toolkit.py index 49d50e01f02..ab5650651b1 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/openapi/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/openapi/toolkit.py @@ -3,13 +3,14 @@ from __future__ import annotations from typing import Any, List +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.agents.agent_toolkits.json.base import create_json_agent from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit from langchain.agents.agent_toolkits.openapi.prompt import DESCRIPTION from langchain.agents.tools import Tool -from langchain.schema.language_model import BaseLanguageModel from langchain.tools import BaseTool from langchain.tools.json.tool import JsonSpec from langchain.tools.requests.tool import ( diff --git a/libs/langchain/langchain/agents/agent_toolkits/pandas/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/pandas/__init__.py index 7af077a8e98..1b4899a2d57 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/pandas/__init__.py +++ b/libs/langchain/langchain/agents/agent_toolkits/pandas/__init__.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Any -from langchain._api.path import as_import_path +from langchain_core._api.path import as_import_path def __getattr__(name: str) -> Any: diff --git a/libs/langchain/langchain/agents/agent_toolkits/playwright/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/playwright/toolkit.py index 44e8ea5cb76..b2410d95f32 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/playwright/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/playwright/toolkit.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import TYPE_CHECKING, List, Optional, Type, 
cast +from langchain_core.pydantic_v1 import Extra, root_validator + from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.pydantic_v1 import Extra, root_validator from langchain.tools.base import BaseTool from langchain.tools.playwright.base import ( BaseBrowserTool, diff --git a/libs/langchain/langchain/agents/agent_toolkits/powerbi/base.py b/libs/langchain/langchain/agents/agent_toolkits/powerbi/base.py index c1aa162cf7c..2a638e3ca17 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/powerbi/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/powerbi/base.py @@ -1,6 +1,8 @@ """Power BI agent.""" from typing import Any, Dict, List, Optional +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents import AgentExecutor from langchain.agents.agent_toolkits.powerbi.prompt import ( POWERBI_PREFIX, @@ -11,7 +13,6 @@ from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.schema.language_model import BaseLanguageModel from langchain.utilities.powerbi import PowerBIDataset diff --git a/libs/langchain/langchain/agents/agent_toolkits/powerbi/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/powerbi/toolkit.py index 2bbd1313c66..ebcff6aa498 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/powerbi/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/powerbi/toolkit.py @@ -1,18 +1,19 @@ """Toolkit for interacting with a Power BI dataset.""" from typing import List, Optional, Union +from langchain_core.prompts import PromptTemplate +from langchain_core.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain_core.pydantic_v1 import Field +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain from langchain.chat_models.base import BaseChatModel -from langchain.prompts import PromptTemplate -from langchain.prompts.chat import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - SystemMessagePromptTemplate, -) -from langchain.pydantic_v1 import Field -from langchain.schema.language_model import BaseLanguageModel from langchain.tools import BaseTool from langchain.tools.powerbi.prompt import ( QUESTION_TO_QUERY_BASE, diff --git a/libs/langchain/langchain/agents/agent_toolkits/python/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/python/__init__.py index 7af077a8e98..1b4899a2d57 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/python/__init__.py +++ b/libs/langchain/langchain/agents/agent_toolkits/python/__init__.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Any -from langchain._api.path import as_import_path +from langchain_core._api.path import as_import_path def __getattr__(name: str) -> Any: diff --git a/libs/langchain/langchain/agents/agent_toolkits/spark/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/spark/__init__.py index 7af077a8e98..1b4899a2d57 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/spark/__init__.py +++ b/libs/langchain/langchain/agents/agent_toolkits/spark/__init__.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Any -from langchain._api.path import as_import_path +from 
langchain_core._api.path import as_import_path def __getattr__(name: str) -> Any: diff --git a/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py b/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py index 0c4238c32a5..5a0071e87c7 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py @@ -1,6 +1,8 @@ """Spark SQL agent.""" from typing import Any, Dict, List, Optional +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit @@ -8,7 +10,6 @@ from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.callbacks.base import BaseCallbackManager, Callbacks from langchain.chains.llm import LLMChain -from langchain.schema.language_model import BaseLanguageModel def create_spark_sql_agent( diff --git a/libs/langchain/langchain/agents/agent_toolkits/spark_sql/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/spark_sql/toolkit.py index 3b7e1641c72..280f38e9eab 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/spark_sql/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/spark_sql/toolkit.py @@ -1,9 +1,10 @@ """Toolkit for interacting with Spark SQL.""" from typing import List +from langchain_core.pydantic_v1 import Field +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.pydantic_v1 import Field -from langchain.schema.language_model import BaseLanguageModel from langchain.tools import BaseTool from langchain.tools.spark_sql.tool import ( InfoSparkSQLTool, diff --git a/libs/langchain/langchain/agents/agent_toolkits/sql/base.py b/libs/langchain/langchain/agents/agent_toolkits/sql/base.py index 21cf249e300..5d5064f3519 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/sql/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/sql/base.py @@ -1,6 +1,14 @@ """SQL agent.""" from typing import Any, Dict, List, Optional, Sequence +from langchain_core.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, +) +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import AIMessage, SystemMessage + from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent from langchain.agents.agent_toolkits.sql.prompt import ( SQL_FUNCTIONS_SUFFIX, @@ -14,13 +22,6 @@ from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.prompts.chat import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - MessagesPlaceholder, -) -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import AIMessage, SystemMessage from langchain.tools import BaseTool diff --git a/libs/langchain/langchain/agents/agent_toolkits/sql/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/sql/toolkit.py index 32b68840a75..f1a217eaa68 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/sql/toolkit.py +++ 
b/libs/langchain/langchain/agents/agent_toolkits/sql/toolkit.py @@ -1,9 +1,10 @@ """Toolkit for interacting with an SQL database.""" from typing import List +from langchain_core.pydantic_v1 import Field +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.pydantic_v1 import Field -from langchain.schema.language_model import BaseLanguageModel from langchain.tools import BaseTool from langchain.tools.sql_database.tool import ( InfoSQLDatabaseTool, diff --git a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/base.py b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/base.py index 73b04492379..66d4f2c0faa 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/base.py @@ -1,6 +1,8 @@ """VectorStore agent.""" from typing import Any, Dict, Optional +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.vectorstore.prompt import PREFIX, ROUTER_PREFIX from langchain.agents.agent_toolkits.vectorstore.toolkit import ( @@ -10,7 +12,6 @@ from langchain.agents.agent_toolkits.vectorstore.toolkit import ( from langchain.agents.mrkl.base import ZeroShotAgent from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.schema.language_model import BaseLanguageModel def create_vectorstore_agent( diff --git a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py index 724a9d45588..d820b79a97c 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py @@ -1,11 +1,12 @@ """Toolkit for interacting with a vector store.""" from typing import List +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.vectorstore import VectorStore + from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.llms.openai import OpenAI -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.vectorstore import VectorStore from langchain.tools import BaseTool from langchain.tools.vectorstore.tool import ( VectorStoreQATool, diff --git a/libs/langchain/langchain/agents/agent_toolkits/xorbits/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/xorbits/__init__.py index 7af077a8e98..1b4899a2d57 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/xorbits/__init__.py +++ b/libs/langchain/langchain/agents/agent_toolkits/xorbits/__init__.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Any -from langchain._api.path import as_import_path +from langchain_core._api.path import as_import_path def __getattr__(name: str) -> Any: diff --git a/libs/langchain/langchain/agents/agent_toolkits/zapier/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/zapier/toolkit.py index 5c5ae5ad5f5..49588cbc38d 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/zapier/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/zapier/toolkit.py @@ -1,7 +1,8 @@ """[DEPRECATED] Zapier Toolkit.""" from typing import List -from langchain._api import warn_deprecated +from 
langchain_core._api import warn_deprecated + from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.tools import BaseTool from langchain.tools.zapier.tool import ZapierNLARunAction diff --git a/libs/langchain/langchain/agents/chat/base.py b/libs/langchain/langchain/agents/chat/base.py index dfe97b222f7..399d003ccf4 100644 --- a/libs/langchain/langchain/agents/chat/base.py +++ b/libs/langchain/langchain/agents/chat/base.py @@ -1,5 +1,14 @@ from typing import Any, List, Optional, Sequence, Tuple +from langchain_core.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import AgentAction, BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.chat.output_parser import ChatOutputParser from langchain.agents.chat.prompt import ( @@ -11,14 +20,6 @@ from langchain.agents.chat.prompt import ( from langchain.agents.utils import validate_tools_single_input from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.prompts.chat import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - SystemMessagePromptTemplate, -) -from langchain.pydantic_v1 import Field -from langchain.schema import AgentAction, BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/agents/chat/output_parser.py b/libs/langchain/langchain/agents/chat/output_parser.py index 6ef7a155c2d..565ab8c519b 100644 --- a/libs/langchain/langchain/agents/chat/output_parser.py +++ b/libs/langchain/langchain/agents/chat/output_parser.py @@ -2,9 +2,10 @@ import json import re from typing import Union +from langchain_core.schema import AgentAction, AgentFinish, OutputParserException + from langchain.agents.agent import AgentOutputParser from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS -from langchain.schema import AgentAction, AgentFinish, OutputParserException FINAL_ANSWER_ACTION = "Final Answer:" diff --git a/libs/langchain/langchain/agents/conversational/base.py b/libs/langchain/langchain/agents/conversational/base.py index ab5f041ec9b..cfae34b51df 100644 --- a/libs/langchain/langchain/agents/conversational/base.py +++ b/libs/langchain/langchain/agents/conversational/base.py @@ -3,6 +3,10 @@ from __future__ import annotations from typing import Any, List, Optional, Sequence +from langchain_core.prompts import PromptTemplate +from langchain_core.pydantic_v1 import Field +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.agent_types import AgentType from langchain.agents.conversational.output_parser import ConvoOutputParser @@ -10,9 +14,6 @@ from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, from langchain.agents.utils import validate_tools_single_input from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain -from langchain.prompts import PromptTemplate -from langchain.pydantic_v1 import Field -from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/agents/conversational/output_parser.py b/libs/langchain/langchain/agents/conversational/output_parser.py index 
c07297db4fa..d57d9922ee3 100644 --- a/libs/langchain/langchain/agents/conversational/output_parser.py +++ b/libs/langchain/langchain/agents/conversational/output_parser.py @@ -1,9 +1,10 @@ import re from typing import Union +from langchain_core.schema import AgentAction, AgentFinish, OutputParserException + from langchain.agents.agent import AgentOutputParser from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS -from langchain.schema import AgentAction, AgentFinish, OutputParserException class ConvoOutputParser(AgentOutputParser): diff --git a/libs/langchain/langchain/agents/conversational_chat/base.py b/libs/langchain/langchain/agents/conversational_chat/base.py index 5b99593b68e..901f44223f2 100644 --- a/libs/langchain/langchain/agents/conversational_chat/base.py +++ b/libs/langchain/langchain/agents/conversational_chat/base.py @@ -3,6 +3,17 @@ from __future__ import annotations from typing import Any, List, Optional, Sequence, Tuple +from langchain_core.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, + SystemMessagePromptTemplate, +) +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import AgentAction, BaseOutputParser, BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import AIMessage, BaseMessage, HumanMessage + from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.conversational_chat.output_parser import ConvoOutputParser from langchain.agents.conversational_chat.prompt import ( @@ -13,16 +24,6 @@ from langchain.agents.conversational_chat.prompt import ( from langchain.agents.utils import validate_tools_single_input from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain -from langchain.prompts.chat import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - MessagesPlaceholder, - SystemMessagePromptTemplate, -) -from langchain.pydantic_v1 import Field -from langchain.schema import AgentAction, BaseOutputParser, BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/agents/conversational_chat/output_parser.py b/libs/langchain/langchain/agents/conversational_chat/output_parser.py index 6a200af31ca..a18e9e9a911 100644 --- a/libs/langchain/langchain/agents/conversational_chat/output_parser.py +++ b/libs/langchain/langchain/agents/conversational_chat/output_parser.py @@ -2,10 +2,11 @@ from __future__ import annotations from typing import Union +from langchain_core.schema import AgentAction, AgentFinish, OutputParserException + from langchain.agents import AgentOutputParser from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS from langchain.output_parsers.json import parse_json_markdown -from langchain.schema import AgentAction, AgentFinish, OutputParserException # Define a class that parses output for conversational agents diff --git a/libs/langchain/langchain/agents/format_scratchpad/log.py b/libs/langchain/langchain/agents/format_scratchpad/log.py index 810556b2c01..06eb965bef8 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/log.py +++ b/libs/langchain/langchain/agents/format_scratchpad/log.py @@ -1,6 +1,6 @@ from typing import List, Tuple -from langchain.schema.agent import AgentAction +from langchain_core.schema.agent import AgentAction 
def format_log_to_str( diff --git a/libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py b/libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py index c370d3a987e..bf39e75ba11 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py +++ b/libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py @@ -1,7 +1,7 @@ from typing import List, Tuple -from langchain.schema.agent import AgentAction -from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.schema.agent import AgentAction +from langchain_core.schema.messages import AIMessage, BaseMessage, HumanMessage def format_log_to_messages( diff --git a/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py b/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py index 7068dfe5da3..16aa3a23db1 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py +++ b/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py @@ -1,8 +1,8 @@ import json from typing import List, Sequence, Tuple -from langchain.schema.agent import AgentAction, AgentActionMessageLog -from langchain.schema.messages import AIMessage, BaseMessage, FunctionMessage +from langchain_core.schema.agent import AgentAction, AgentActionMessageLog +from langchain_core.schema.messages import AIMessage, BaseMessage, FunctionMessage def _convert_agent_action_to_messages( diff --git a/libs/langchain/langchain/agents/format_scratchpad/openai_tools.py b/libs/langchain/langchain/agents/format_scratchpad/openai_tools.py index b8481631fcc..b2f490c70cc 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/openai_tools.py +++ b/libs/langchain/langchain/agents/format_scratchpad/openai_tools.py @@ -1,14 +1,15 @@ import json from typing import List, Sequence, Tuple -from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction -from langchain.schema.agent import AgentAction -from langchain.schema.messages import ( +from langchain_core.schema.agent import AgentAction +from langchain_core.schema.messages import ( AIMessage, BaseMessage, ToolMessage, ) +from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction + def _create_tool_message( agent_action: OpenAIToolAgentAction, observation: str diff --git a/libs/langchain/langchain/agents/format_scratchpad/xml.py b/libs/langchain/langchain/agents/format_scratchpad/xml.py index 7e2539dae33..a7db742c866 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/xml.py +++ b/libs/langchain/langchain/agents/format_scratchpad/xml.py @@ -1,6 +1,6 @@ from typing import List, Tuple -from langchain.schema.agent import AgentAction +from langchain_core.schema.agent import AgentAction def format_xml( diff --git a/libs/langchain/langchain/agents/initialize.py b/libs/langchain/langchain/agents/initialize.py index 378d2c9d116..c114d5733cc 100644 --- a/libs/langchain/langchain/agents/initialize.py +++ b/libs/langchain/langchain/agents/initialize.py @@ -1,11 +1,12 @@ """Load agent.""" from typing import Any, Optional, Sequence +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import AgentExecutor from langchain.agents.agent_types import AgentType from langchain.agents.loading import AGENT_TO_CLASS, load_agent from langchain.callbacks.base import BaseCallbackManager -from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool diff --git 
a/libs/langchain/langchain/agents/load_tools.py b/libs/langchain/langchain/agents/load_tools.py index 60377c137d0..2e7254343a3 100644 --- a/libs/langchain/langchain/agents/load_tools.py +++ b/libs/langchain/langchain/agents/load_tools.py @@ -19,7 +19,7 @@ from typing import Any, Dict, List, Optional, Callable, Tuple from mypy_extensions import Arg, KwArg from langchain.agents.tools import Tool -from langchain.schema.language_model import BaseLanguageModel +from langchain_core.schema.language_model import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs diff --git a/libs/langchain/langchain/agents/loading.py b/libs/langchain/langchain/agents/loading.py index 32e882c94ad..8915b45dea0 100644 --- a/libs/langchain/langchain/agents/loading.py +++ b/libs/langchain/langchain/agents/loading.py @@ -5,12 +5,12 @@ from pathlib import Path from typing import Any, List, Optional, Union import yaml +from langchain_core.schema.language_model import BaseLanguageModel from langchain.agents.agent import BaseMultiActionAgent, BaseSingleActionAgent from langchain.agents.tools import Tool from langchain.agents.types import AGENT_TO_CLASS from langchain.chains.loading import load_chain, load_chain_from_config -from langchain.schema.language_model import BaseLanguageModel from langchain.utilities.loading import try_load_from_hub logger = logging.getLogger(__file__) diff --git a/libs/langchain/langchain/agents/mrkl/base.py b/libs/langchain/langchain/agents/mrkl/base.py index 177291c03f1..7fef9acdf4f 100644 --- a/libs/langchain/langchain/agents/mrkl/base.py +++ b/libs/langchain/langchain/agents/mrkl/base.py @@ -3,6 +3,10 @@ from __future__ import annotations from typing import Any, Callable, List, NamedTuple, Optional, Sequence +from langchain_core.prompts import PromptTemplate +from langchain_core.pydantic_v1 import Field +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent_types import AgentType from langchain.agents.mrkl.output_parser import MRKLOutputParser @@ -11,9 +15,6 @@ from langchain.agents.tools import Tool from langchain.agents.utils import validate_tools_single_input from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain -from langchain.prompts import PromptTemplate -from langchain.pydantic_v1 import Field -from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/agents/mrkl/output_parser.py b/libs/langchain/langchain/agents/mrkl/output_parser.py index 60b1f58b465..058ab512b35 100644 --- a/libs/langchain/langchain/agents/mrkl/output_parser.py +++ b/libs/langchain/langchain/agents/mrkl/output_parser.py @@ -1,9 +1,10 @@ import re from typing import Union +from langchain_core.schema import AgentAction, AgentFinish, OutputParserException + from langchain.agents.agent import AgentOutputParser from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS -from langchain.schema import AgentAction, AgentFinish, OutputParserException FINAL_ANSWER_ACTION = "Final Answer:" MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = ( diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index 66323125c0f..19edcacec90 100644 --- 
a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -5,11 +5,12 @@ from json import JSONDecodeError from time import sleep from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union +from langchain_core.load import dumpd +from langchain_core.pydantic_v1 import Field +from langchain_core.runnables import RunnableConfig, RunnableSerializable +from langchain_core.schema.agent import AgentAction, AgentFinish + from langchain.callbacks.manager import CallbackManager -from langchain.load import dumpd -from langchain.pydantic_v1 import Field -from langchain.schema.agent import AgentAction, AgentFinish -from langchain.schema.runnable import RunnableConfig, RunnableSerializable from langchain.tools.base import BaseTool from langchain.tools.render import format_tool_to_openai_tool @@ -102,7 +103,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): from langchain_experimental.openai_assistant import OpenAIAssistantRunnable from langchain.agents import AgentExecutor - from langchain.schema.agent import AgentFinish + from langchain_core.schema.agent import AgentFinish from langchain.tools import E2BDataAnalysisTool diff --git a/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py b/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py index 75a1e53f295..a94e1555f09 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py @@ -1,12 +1,13 @@ """Memory used to save agent output AND intermediate steps.""" from typing import Any, Dict, List +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import BaseMessage, get_buffer_string + from langchain.agents.format_scratchpad.openai_functions import ( format_to_openai_function_messages, ) from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import BaseMessage, get_buffer_string class AgentTokenBufferMemory(BaseChatMemory): diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index e102562e57f..cfc0100831b 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -1,6 +1,24 @@ """Module implements an agent that uses OpenAI's APIs function enabled API.""" from typing import Any, List, Optional, Sequence, Tuple, Union +from langchain_core.prompts.chat import ( + BaseMessagePromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, +) +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import ( + AgentAction, + AgentFinish, + BasePromptTemplate, +) +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import ( + BaseMessage, + SystemMessage, +) + from langchain.agents import BaseSingleActionAgent from langchain.agents.format_scratchpad.openai_functions import ( format_to_openai_function_messages, @@ -11,23 +29,6 @@ from langchain.agents.output_parsers.openai_functions import ( from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.chat_models.openai import 
ChatOpenAI -from langchain.prompts.chat import ( - BaseMessagePromptTemplate, - ChatPromptTemplate, - HumanMessagePromptTemplate, - MessagesPlaceholder, -) -from langchain.pydantic_v1 import root_validator -from langchain.schema import ( - AgentAction, - AgentFinish, - BasePromptTemplate, -) -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import ( - BaseMessage, - SystemMessage, -) from langchain.tools.base import BaseTool from langchain.tools.render import format_tool_to_openai_function diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py index f9c21e10624..2523ce44e2c 100644 --- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py @@ -3,6 +3,27 @@ import json from json import JSONDecodeError from typing import Any, List, Optional, Sequence, Tuple, Union +from langchain_core.prompts.chat import ( + BaseMessagePromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, +) +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import ( + AgentAction, + AgentFinish, + BasePromptTemplate, + OutputParserException, +) +from langchain_core.schema.agent import AgentActionMessageLog +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import ( + AIMessage, + BaseMessage, + SystemMessage, +) + from langchain.agents import BaseMultiActionAgent from langchain.agents.format_scratchpad.openai_functions import ( format_to_openai_function_messages, @@ -10,26 +31,6 @@ from langchain.agents.format_scratchpad.openai_functions import ( from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.chat_models.openai import ChatOpenAI -from langchain.prompts.chat import ( - BaseMessagePromptTemplate, - ChatPromptTemplate, - HumanMessagePromptTemplate, - MessagesPlaceholder, -) -from langchain.pydantic_v1 import root_validator -from langchain.schema import ( - AgentAction, - AgentFinish, - BasePromptTemplate, - OutputParserException, -) -from langchain.schema.agent import AgentActionMessageLog -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import ( - AIMessage, - BaseMessage, - SystemMessage, -) from langchain.tools import BaseTool # For backwards compatibility diff --git a/libs/langchain/langchain/agents/output_parsers/json.py b/libs/langchain/langchain/agents/output_parsers/json.py index c237cc819d1..739dd2482c1 100644 --- a/libs/langchain/langchain/agents/output_parsers/json.py +++ b/libs/langchain/langchain/agents/output_parsers/json.py @@ -3,9 +3,10 @@ from __future__ import annotations import logging from typing import Union +from langchain_core.schema import AgentAction, AgentFinish, OutputParserException + from langchain.agents.agent import AgentOutputParser from langchain.output_parsers.json import parse_json_markdown -from langchain.schema import AgentAction, AgentFinish, OutputParserException logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/agents/output_parsers/openai_functions.py b/libs/langchain/langchain/agents/output_parsers/openai_functions.py index c6b4aadb718..ee83eb337b4 100644 --- a/libs/langchain/langchain/agents/output_parsers/openai_functions.py +++ 
b/libs/langchain/langchain/agents/output_parsers/openai_functions.py @@ -3,18 +3,19 @@ import json from json import JSONDecodeError from typing import List, Union -from langchain.agents.agent import AgentOutputParser -from langchain.schema import ( +from langchain_core.schema import ( AgentAction, AgentFinish, OutputParserException, ) -from langchain.schema.agent import AgentActionMessageLog -from langchain.schema.messages import ( +from langchain_core.schema.agent import AgentActionMessageLog +from langchain_core.schema.messages import ( AIMessage, BaseMessage, ) -from langchain.schema.output import ChatGeneration, Generation +from langchain_core.schema.output import ChatGeneration, Generation + +from langchain.agents.agent import AgentOutputParser class OpenAIFunctionsAgentOutputParser(AgentOutputParser): diff --git a/libs/langchain/langchain/agents/output_parsers/openai_tools.py b/libs/langchain/langchain/agents/output_parsers/openai_tools.py index f14fd9e3a96..a92d6f6b6ca 100644 --- a/libs/langchain/langchain/agents/output_parsers/openai_tools.py +++ b/libs/langchain/langchain/agents/output_parsers/openai_tools.py @@ -3,18 +3,19 @@ import json from json import JSONDecodeError from typing import List, Union -from langchain.agents.agent import MultiActionAgentOutputParser -from langchain.schema import ( +from langchain_core.schema import ( AgentAction, AgentFinish, OutputParserException, ) -from langchain.schema.agent import AgentActionMessageLog -from langchain.schema.messages import ( +from langchain_core.schema.agent import AgentActionMessageLog +from langchain_core.schema.messages import ( AIMessage, BaseMessage, ) -from langchain.schema.output import ChatGeneration, Generation +from langchain_core.schema.output import ChatGeneration, Generation + +from langchain.agents.agent import MultiActionAgentOutputParser class OpenAIToolAgentAction(AgentActionMessageLog): diff --git a/libs/langchain/langchain/agents/output_parsers/react_json_single_input.py b/libs/langchain/langchain/agents/output_parsers/react_json_single_input.py index 8878d5aee6b..a281388d708 100644 --- a/libs/langchain/langchain/agents/output_parsers/react_json_single_input.py +++ b/libs/langchain/langchain/agents/output_parsers/react_json_single_input.py @@ -2,9 +2,10 @@ import json import re from typing import Union +from langchain_core.schema import AgentAction, AgentFinish, OutputParserException + from langchain.agents.agent import AgentOutputParser from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS -from langchain.schema import AgentAction, AgentFinish, OutputParserException FINAL_ANSWER_ACTION = "Final Answer:" diff --git a/libs/langchain/langchain/agents/output_parsers/react_single_input.py b/libs/langchain/langchain/agents/output_parsers/react_single_input.py index 573fa7f4486..9c201d98c27 100644 --- a/libs/langchain/langchain/agents/output_parsers/react_single_input.py +++ b/libs/langchain/langchain/agents/output_parsers/react_single_input.py @@ -1,9 +1,10 @@ import re from typing import Union +from langchain_core.schema import AgentAction, AgentFinish, OutputParserException + from langchain.agents.agent import AgentOutputParser from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS -from langchain.schema import AgentAction, AgentFinish, OutputParserException FINAL_ANSWER_ACTION = "Final Answer:" MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = ( diff --git a/libs/langchain/langchain/agents/output_parsers/self_ask.py b/libs/langchain/langchain/agents/output_parsers/self_ask.py index 
6423187b77f..ae665a58a26 100644 --- a/libs/langchain/langchain/agents/output_parsers/self_ask.py +++ b/libs/langchain/langchain/agents/output_parsers/self_ask.py @@ -1,7 +1,8 @@ from typing import Sequence, Union +from langchain_core.schema import AgentAction, AgentFinish, OutputParserException + from langchain.agents.agent import AgentOutputParser -from langchain.schema import AgentAction, AgentFinish, OutputParserException class SelfAskOutputParser(AgentOutputParser): diff --git a/libs/langchain/langchain/agents/output_parsers/xml.py b/libs/langchain/langchain/agents/output_parsers/xml.py index 20ff928c7f9..15f0916db7a 100644 --- a/libs/langchain/langchain/agents/output_parsers/xml.py +++ b/libs/langchain/langchain/agents/output_parsers/xml.py @@ -1,7 +1,8 @@ from typing import Union +from langchain_core.schema import AgentAction, AgentFinish + from langchain.agents import AgentOutputParser -from langchain.schema import AgentAction, AgentFinish class XMLAgentOutputParser(AgentOutputParser): diff --git a/libs/langchain/langchain/agents/react/base.py b/libs/langchain/langchain/agents/react/base.py index 323104cd30d..34fad1f6e6d 100644 --- a/libs/langchain/langchain/agents/react/base.py +++ b/libs/langchain/langchain/agents/react/base.py @@ -1,6 +1,10 @@ """Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf.""" from typing import Any, List, Optional, Sequence +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent_types import AgentType from langchain.agents.react.output_parser import ReActOutputParser @@ -10,9 +14,6 @@ from langchain.agents.tools import Tool from langchain.agents.utils import validate_tools_single_input from langchain.docstore.base import Docstore from langchain.docstore.document import Document -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/agents/react/output_parser.py b/libs/langchain/langchain/agents/react/output_parser.py index 7f3e75b3f21..fe685aeb7bc 100644 --- a/libs/langchain/langchain/agents/react/output_parser.py +++ b/libs/langchain/langchain/agents/react/output_parser.py @@ -1,8 +1,9 @@ import re from typing import Union +from langchain_core.schema import AgentAction, AgentFinish, OutputParserException + from langchain.agents.agent import AgentOutputParser -from langchain.schema import AgentAction, AgentFinish, OutputParserException class ReActOutputParser(AgentOutputParser): diff --git a/libs/langchain/langchain/agents/react/textworld_prompt.py b/libs/langchain/langchain/agents/react/textworld_prompt.py index f01b9a6dd95..26cfd49aa29 100644 --- a/libs/langchain/langchain/agents/react/textworld_prompt.py +++ b/libs/langchain/langchain/agents/react/textworld_prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate EXAMPLES = [ """Setup: You are now playing a fast paced round of TextWorld! 
Here is your task for diff --git a/libs/langchain/langchain/agents/react/wiki_prompt.py b/libs/langchain/langchain/agents/react/wiki_prompt.py index 866facd1fdf..9db6cf92918 100644 --- a/libs/langchain/langchain/agents/react/wiki_prompt.py +++ b/libs/langchain/langchain/agents/react/wiki_prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate EXAMPLES = [ """Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into? diff --git a/libs/langchain/langchain/agents/schema.py b/libs/langchain/langchain/agents/schema.py index ee01b3a0ce0..ba730f48034 100644 --- a/libs/langchain/langchain/agents/schema.py +++ b/libs/langchain/langchain/agents/schema.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Tuple -from langchain.prompts.chat import ChatPromptTemplate -from langchain.schema import AgentAction +from langchain_core.prompts.chat import ChatPromptTemplate +from langchain_core.schema import AgentAction class AgentScratchPadChatPromptTemplate(ChatPromptTemplate): diff --git a/libs/langchain/langchain/agents/self_ask_with_search/base.py b/libs/langchain/langchain/agents/self_ask_with_search/base.py index ddd77f2277e..ce2e2581748 100644 --- a/libs/langchain/langchain/agents/self_ask_with_search/base.py +++ b/libs/langchain/langchain/agents/self_ask_with_search/base.py @@ -1,15 +1,16 @@ """Chain that does self-ask with search.""" from typing import Any, Sequence, Union +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent_types import AgentType from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputParser from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.tools import Tool from langchain.agents.utils import validate_tools_single_input -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool from langchain.utilities.google_serper import GoogleSerperAPIWrapper from langchain.utilities.serpapi import SerpAPIWrapper diff --git a/libs/langchain/langchain/agents/self_ask_with_search/prompt.py b/libs/langchain/langchain/agents/self_ask_with_search/prompt.py index c82de28dfbe..c9154785cd9 100644 --- a/libs/langchain/langchain/agents/self_ask_with_search/prompt.py +++ b/libs/langchain/langchain/agents/self_ask_with_search/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate _DEFAULT_TEMPLATE = """Question: Who lived longer, Muhammad Ali or Alan Turing? Are follow up questions needed here: Yes. 
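Every hunk in this part of the patch follows the same mechanical pattern: imports of schema, prompt, and pydantic_v1 symbols move from the langchain package to the new langchain_core package, while the remaining langchain imports stay where they were. A minimal sketch of what downstream code looks like after the split, using only import paths that appear in the hunks above; the prompt text, the helper function, and the llm.predict call are illustrative assumptions and are not part of the patch:

    # New import paths introduced by this patch (previously langchain.prompts.prompt
    # and langchain.schema.language_model).
    from langchain_core.prompts.prompt import PromptTemplate
    from langchain_core.schema.language_model import BaseLanguageModel

    # Same classes as before the split; only the package they are imported from changes.
    prompt = PromptTemplate(
        input_variables=["question"],
        template="Question: {question}\nAnswer:",
    )

    def answer(llm: BaseLanguageModel, question: str) -> str:
        # llm.predict(...) stands in for whatever model call the caller already uses;
        # it is shown here only to make the sketch self-contained.
        return llm.predict(prompt.format(question=question))
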
diff --git a/libs/langchain/langchain/agents/structured_chat/base.py b/libs/langchain/langchain/agents/structured_chat/base.py index 7c6fdd06993..0f1158b4649 100644 --- a/libs/langchain/langchain/agents/structured_chat/base.py +++ b/libs/langchain/langchain/agents/structured_chat/base.py @@ -1,6 +1,15 @@ import re from typing import Any, List, Optional, Sequence, Tuple +from langchain_core.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import AgentAction, BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.structured_chat.output_parser import ( StructuredChatOutputParserWithRetries, @@ -8,14 +17,6 @@ from langchain.agents.structured_chat.output_parser import ( from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.prompts.chat import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - SystemMessagePromptTemplate, -) -from langchain.pydantic_v1 import Field -from langchain.schema import AgentAction, BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel from langchain.tools import BaseTool HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}" diff --git a/libs/langchain/langchain/agents/structured_chat/output_parser.py b/libs/langchain/langchain/agents/structured_chat/output_parser.py index ab5d449bfcf..2a961d50fee 100644 --- a/libs/langchain/langchain/agents/structured_chat/output_parser.py +++ b/libs/langchain/langchain/agents/structured_chat/output_parser.py @@ -5,12 +5,13 @@ import logging import re from typing import Optional, Union +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import AgentAction, AgentFinish, OutputParserException +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.agents.agent import AgentOutputParser from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS from langchain.output_parsers import OutputFixingParser -from langchain.pydantic_v1 import Field -from langchain.schema import AgentAction, AgentFinish, OutputParserException -from langchain.schema.language_model import BaseLanguageModel logger = logging.getLogger(__name__) @@ -79,7 +80,7 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser): ) -> StructuredChatOutputParserWithRetries: if llm is not None: base_parser = base_parser or StructuredChatOutputParser() - output_fixing_parser = OutputFixingParser.from_llm( + output_fixing_parser: OutputFixingParser = OutputFixingParser.from_llm( llm=llm, parser=base_parser ) return cls(output_fixing_parser=output_fixing_parser) diff --git a/libs/langchain/langchain/agents/xml/base.py b/libs/langchain/langchain/agents/xml/base.py index 815da7f5fe9..c3b6966f38d 100644 --- a/libs/langchain/langchain/agents/xml/base.py +++ b/libs/langchain/langchain/agents/xml/base.py @@ -1,12 +1,13 @@ from typing import Any, List, Tuple, Union +from langchain_core.prompts.chat import AIMessagePromptTemplate, ChatPromptTemplate +from langchain_core.schema import AgentAction, AgentFinish + from langchain.agents.agent import BaseSingleActionAgent from langchain.agents.output_parsers.xml import XMLAgentOutputParser from langchain.agents.xml.prompt import agent_instructions from 
langchain.callbacks.base import Callbacks from langchain.chains.llm import LLMChain -from langchain.prompts.chat import AIMessagePromptTemplate, ChatPromptTemplate -from langchain.schema import AgentAction, AgentFinish from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/base_language.py b/libs/langchain/langchain/base_language.py index 30323070e54..9912c1a95f3 100644 --- a/libs/langchain/langchain/base_language.py +++ b/libs/langchain/langchain/base_language.py @@ -1,6 +1,6 @@ """Deprecated module for BaseLanguageModel class, kept for backwards compatibility.""" from __future__ import annotations -from langchain.schema.language_model import BaseLanguageModel +from langchain_core.schema.language_model import BaseLanguageModel __all__ = ["BaseLanguageModel"] diff --git a/libs/langchain/langchain/cache.py b/libs/langchain/langchain/cache.py index a45c4328211..46197c119a6 100644 --- a/libs/langchain/langchain/cache.py +++ b/libs/langchain/langchain/cache.py @@ -51,12 +51,13 @@ try: except ImportError: from sqlalchemy.ext.declarative import declarative_base +from langchain_core.load.dump import dumps +from langchain_core.load.load import loads +from langchain_core.schema import ChatGeneration, Generation +from langchain_core.schema.cache import RETURN_VAL_TYPE, BaseCache +from langchain_core.schema.embeddings import Embeddings + from langchain.llms.base import LLM, get_prompts -from langchain.load.dump import dumps -from langchain.load.load import loads -from langchain.schema import ChatGeneration, Generation -from langchain.schema.cache import RETURN_VAL_TYPE, BaseCache -from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_env from langchain.vectorstores.redis import Redis as RedisVectorstore diff --git a/libs/langchain/langchain/callbacks/aim_callback.py b/libs/langchain/langchain/callbacks/aim_callback.py index 9526f34f14a..e43f44ea90c 100644 --- a/libs/langchain/langchain/callbacks/aim_callback.py +++ b/libs/langchain/langchain/callbacks/aim_callback.py @@ -1,8 +1,9 @@ from copy import deepcopy from typing import Any, Dict, List, Optional +from langchain_core.schema import AgentAction, AgentFinish, LLMResult + from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult def import_aim() -> Any: diff --git a/libs/langchain/langchain/callbacks/argilla_callback.py b/libs/langchain/langchain/callbacks/argilla_callback.py index eb2d2da18e6..e02ae12cc5f 100644 --- a/libs/langchain/langchain/callbacks/argilla_callback.py +++ b/libs/langchain/langchain/callbacks/argilla_callback.py @@ -2,10 +2,10 @@ import os import warnings from typing import Any, Dict, List, Optional +from langchain_core.schema import AgentAction, AgentFinish, LLMResult from packaging.version import parse from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult class ArgillaCallbackHandler(BaseCallbackHandler): diff --git a/libs/langchain/langchain/callbacks/arize_callback.py b/libs/langchain/langchain/callbacks/arize_callback.py index a57de3a905d..191e9a1dfa7 100644 --- a/libs/langchain/langchain/callbacks/arize_callback.py +++ b/libs/langchain/langchain/callbacks/arize_callback.py @@ -1,9 +1,10 @@ from datetime import datetime from typing import Any, Dict, List, Optional +from langchain_core.schema import AgentAction, AgentFinish, LLMResult + from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.utils import 
import_pandas -from langchain.schema import AgentAction, AgentFinish, LLMResult class ArizeCallbackHandler(BaseCallbackHandler): diff --git a/libs/langchain/langchain/callbacks/arthur_callback.py b/libs/langchain/langchain/callbacks/arthur_callback.py index 5584175b7be..f1d5d39b601 100644 --- a/libs/langchain/langchain/callbacks/arthur_callback.py +++ b/libs/langchain/langchain/callbacks/arthur_callback.py @@ -9,9 +9,9 @@ from time import time from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional import numpy as np +from langchain_core.schema import AgentAction, AgentFinish, LLMResult from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult if TYPE_CHECKING: import arthurai diff --git a/libs/langchain/langchain/callbacks/base.py b/libs/langchain/langchain/callbacks/base.py index 151235af2ef..9b9d189d3f6 100644 --- a/libs/langchain/langchain/callbacks/base.py +++ b/libs/langchain/langchain/callbacks/base.py @@ -1,7 +1,7 @@ """Base callback handler that can be used to handle callbacks in langchain.""" from __future__ import annotations -from langchain.schema.callbacks.base import ( +from langchain_core.callbacks.base import ( AsyncCallbackHandler, BaseCallbackHandler, BaseCallbackManager, diff --git a/libs/langchain/langchain/callbacks/clearml_callback.py b/libs/langchain/langchain/callbacks/clearml_callback.py index daa495693bc..d3ea80ff4c8 100644 --- a/libs/langchain/langchain/callbacks/clearml_callback.py +++ b/libs/langchain/langchain/callbacks/clearml_callback.py @@ -5,6 +5,8 @@ from copy import deepcopy from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence +from langchain_core.schema import AgentAction, AgentFinish, LLMResult + from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.utils import ( BaseMetadataCallbackHandler, @@ -15,7 +17,6 @@ from langchain.callbacks.utils import ( import_textstat, load_json, ) -from langchain.schema import AgentAction, AgentFinish, LLMResult if TYPE_CHECKING: import pandas as pd diff --git a/libs/langchain/langchain/callbacks/comet_ml_callback.py b/libs/langchain/langchain/callbacks/comet_ml_callback.py index 1e8aabb1c16..d38466524e0 100644 --- a/libs/langchain/langchain/callbacks/comet_ml_callback.py +++ b/libs/langchain/langchain/callbacks/comet_ml_callback.py @@ -3,6 +3,8 @@ from copy import deepcopy from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence +from langchain_core.schema import AgentAction, AgentFinish, Generation, LLMResult + import langchain from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.utils import ( @@ -12,7 +14,6 @@ from langchain.callbacks.utils import ( import_spacy, import_textstat, ) -from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult LANGCHAIN_MODEL_NAME = "langchain-model" diff --git a/libs/langchain/langchain/callbacks/confident_callback.py b/libs/langchain/langchain/callbacks/confident_callback.py index 9d8f494c938..3cba3de313c 100644 --- a/libs/langchain/langchain/callbacks/confident_callback.py +++ b/libs/langchain/langchain/callbacks/confident_callback.py @@ -4,7 +4,7 @@ import warnings from typing import Any, Dict, List, Optional, Union from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult +from langchain_core.schema import AgentAction, AgentFinish, LLMResult class 
DeepEvalCallbackHandler(BaseCallbackHandler): diff --git a/libs/langchain/langchain/callbacks/context_callback.py b/libs/langchain/langchain/callbacks/context_callback.py index 6306763114e..550341ab116 100644 --- a/libs/langchain/langchain/callbacks/context_callback.py +++ b/libs/langchain/langchain/callbacks/context_callback.py @@ -3,12 +3,13 @@ import os from typing import Any, Dict, List from uuid import UUID -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import ( +from langchain_core.schema import ( BaseMessage, LLMResult, ) +from langchain.callbacks.base import BaseCallbackHandler + def import_context() -> Any: """Import the `getcontext` package.""" diff --git a/libs/langchain/langchain/callbacks/file.py b/libs/langchain/langchain/callbacks/file.py index d5dff489008..8e386c96be7 100644 --- a/libs/langchain/langchain/callbacks/file.py +++ b/libs/langchain/langchain/callbacks/file.py @@ -1,9 +1,10 @@ """Callback Handler that writes to a file.""" from typing import Any, Dict, Optional, TextIO, cast +from langchain_core.schema import AgentAction, AgentFinish +from langchain_core.utils.input import print_text + from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish -from langchain.utils.input import print_text class FileCallbackHandler(BaseCallbackHandler): diff --git a/libs/langchain/langchain/callbacks/flyte_callback.py b/libs/langchain/langchain/callbacks/flyte_callback.py index fbdf2810e5b..c34ea63d923 100644 --- a/libs/langchain/langchain/callbacks/flyte_callback.py +++ b/libs/langchain/langchain/callbacks/flyte_callback.py @@ -5,6 +5,8 @@ import logging from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, List, Tuple +from langchain_core.schema import AgentAction, AgentFinish, LLMResult + from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.utils import ( BaseMetadataCallbackHandler, @@ -13,7 +15,6 @@ from langchain.callbacks.utils import ( import_spacy, import_textstat, ) -from langchain.schema import AgentAction, AgentFinish, LLMResult if TYPE_CHECKING: import flytekit diff --git a/libs/langchain/langchain/callbacks/infino_callback.py b/libs/langchain/langchain/callbacks/infino_callback.py index 2d35850554d..815fa3d2259 100644 --- a/libs/langchain/langchain/callbacks/infino_callback.py +++ b/libs/langchain/langchain/callbacks/infino_callback.py @@ -1,9 +1,10 @@ import time from typing import Any, Dict, List, Optional, cast +from langchain_core.schema import AgentAction, AgentFinish, LLMResult +from langchain_core.schema.messages import BaseMessage + from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult -from langchain.schema.messages import BaseMessage def import_infino() -> Any: diff --git a/libs/langchain/langchain/callbacks/labelstudio_callback.py b/libs/langchain/langchain/callbacks/labelstudio_callback.py index de3c669b00a..303a4315bdc 100644 --- a/libs/langchain/langchain/callbacks/labelstudio_callback.py +++ b/libs/langchain/langchain/callbacks/labelstudio_callback.py @@ -5,8 +5,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union from uuid import UUID -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import ( +from langchain_core.schema import ( AgentAction, AgentFinish, BaseMessage, @@ -15,6 +14,8 @@ from langchain.schema import ( LLMResult, ) +from langchain.callbacks.base import BaseCallbackHandler + 
class LabelStudioMode(Enum): """Label Studio mode enumerator.""" diff --git a/libs/langchain/langchain/callbacks/llmonitor_callback.py b/libs/langchain/langchain/callbacks/llmonitor_callback.py index 886e583f777..8202affa9da 100644 --- a/libs/langchain/langchain/callbacks/llmonitor_callback.py +++ b/libs/langchain/langchain/callbacks/llmonitor_callback.py @@ -8,12 +8,12 @@ from typing import Any, Dict, List, Union, cast from uuid import UUID import requests +from langchain_core.schema.agent import AgentAction, AgentFinish +from langchain_core.schema.messages import BaseMessage +from langchain_core.schema.output import LLMResult from packaging.version import parse from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema.agent import AgentAction, AgentFinish -from langchain.schema.messages import BaseMessage -from langchain.schema.output import LLMResult logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index 317274fe03f..bcd7bbd9610 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -8,9 +8,7 @@ from typing import ( Optional, ) -from langchain.callbacks.openai_info import OpenAICallbackHandler -from langchain.callbacks.tracers.wandb import WandbTracer -from langchain.schema.callbacks.manager import ( +from langchain_core.callbacks.manager import ( AsyncCallbackManager, AsyncCallbackManagerForChainGroup, AsyncCallbackManagerForChainRun, @@ -40,6 +38,9 @@ from langchain.schema.callbacks.manager import ( tracing_v2_enabled, ) +from langchain.callbacks.openai_info import OpenAICallbackHandler +from langchain.callbacks.tracers.wandb import WandbTracer + logger = logging.getLogger(__name__) openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( diff --git a/libs/langchain/langchain/callbacks/mlflow_callback.py b/libs/langchain/langchain/callbacks/mlflow_callback.py index 9fc48641cdb..881d6745470 100644 --- a/libs/langchain/langchain/callbacks/mlflow_callback.py +++ b/libs/langchain/langchain/callbacks/mlflow_callback.py @@ -7,6 +7,8 @@ from copy import deepcopy from pathlib import Path from typing import Any, Dict, List, Optional, Union +from langchain_core.schema import AgentAction, AgentFinish, LLMResult + from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.utils import ( BaseMetadataCallbackHandler, @@ -16,7 +18,6 @@ from langchain.callbacks.utils import ( import_spacy, import_textstat, ) -from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/callbacks/openai_info.py b/libs/langchain/langchain/callbacks/openai_info.py index f9d49e41882..19d40b94957 100644 --- a/libs/langchain/langchain/callbacks/openai_info.py +++ b/libs/langchain/langchain/callbacks/openai_info.py @@ -1,8 +1,9 @@ """Callback Handler that prints to std out.""" from typing import Any, Dict, List +from langchain_core.schema import LLMResult + from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import LLMResult MODEL_COST_PER_1K_TOKENS = { # GPT-4 input diff --git a/libs/langchain/langchain/callbacks/promptlayer_callback.py b/libs/langchain/langchain/callbacks/promptlayer_callback.py index 749bb6b0064..567734a7c12 100644 --- a/libs/langchain/langchain/callbacks/promptlayer_callback.py +++ b/libs/langchain/langchain/callbacks/promptlayer_callback.py @@ -5,12 +5,11 @@ import 
datetime from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple from uuid import UUID -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import ( +from langchain_core.schema import ( ChatGeneration, LLMResult, ) -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( AIMessage, BaseMessage, ChatMessage, @@ -18,6 +17,8 @@ from langchain.schema.messages import ( SystemMessage, ) +from langchain.callbacks.base import BaseCallbackHandler + if TYPE_CHECKING: import promptlayer diff --git a/libs/langchain/langchain/callbacks/sagemaker_callback.py b/libs/langchain/langchain/callbacks/sagemaker_callback.py index 9b532a04c15..913044fc56a 100644 --- a/libs/langchain/langchain/callbacks/sagemaker_callback.py +++ b/libs/langchain/langchain/callbacks/sagemaker_callback.py @@ -5,11 +5,12 @@ import tempfile from copy import deepcopy from typing import Any, Dict, List, Optional +from langchain_core.schema import AgentAction, AgentFinish, LLMResult + from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.utils import ( flatten_dict, ) -from langchain.schema import AgentAction, AgentFinish, LLMResult def save_json(data: dict, file_path: str) -> None: diff --git a/libs/langchain/langchain/callbacks/stdout.py b/libs/langchain/langchain/callbacks/stdout.py index ef3bab0618e..754e58248e4 100644 --- a/libs/langchain/langchain/callbacks/stdout.py +++ b/libs/langchain/langchain/callbacks/stdout.py @@ -1,3 +1,3 @@ -from langchain.schema.callbacks.stdout import StdOutCallbackHandler +from langchain_core.callbacks.stdout import StdOutCallbackHandler __all__ = ["StdOutCallbackHandler"] diff --git a/libs/langchain/langchain/callbacks/streaming_aiter.py b/libs/langchain/langchain/callbacks/streaming_aiter.py index e49bea629d3..46dd77b6d5c 100644 --- a/libs/langchain/langchain/callbacks/streaming_aiter.py +++ b/libs/langchain/langchain/callbacks/streaming_aiter.py @@ -3,8 +3,9 @@ from __future__ import annotations import asyncio from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast +from langchain_core.schema.output import LLMResult + from langchain.callbacks.base import AsyncCallbackHandler -from langchain.schema.output import LLMResult # TODO If used by two LLM runs in parallel this won't work as expected diff --git a/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py b/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py index 51e85f96ae6..8e93c0e1167 100644 --- a/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py +++ b/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py @@ -2,8 +2,9 @@ from __future__ import annotations from typing import Any, Dict, List, Optional +from langchain_core.schema import LLMResult + from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler -from langchain.schema import LLMResult DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"] diff --git a/libs/langchain/langchain/callbacks/streaming_stdout.py b/libs/langchain/langchain/callbacks/streaming_stdout.py index 26be79bd192..e2a22232b57 100644 --- a/libs/langchain/langchain/callbacks/streaming_stdout.py +++ b/libs/langchain/langchain/callbacks/streaming_stdout.py @@ -1,4 +1,4 @@ """Callback Handler streams to stdout on new llm token.""" -from langchain.schema.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler __all__ = 
["StreamingStdOutCallbackHandler"] diff --git a/libs/langchain/langchain/callbacks/streamlit/streamlit_callback_handler.py b/libs/langchain/langchain/callbacks/streamlit/streamlit_callback_handler.py index d43af17a358..b8c89eb7e5f 100644 --- a/libs/langchain/langchain/callbacks/streamlit/streamlit_callback_handler.py +++ b/libs/langchain/langchain/callbacks/streamlit/streamlit_callback_handler.py @@ -5,9 +5,10 @@ from __future__ import annotations from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional +from langchain_core.schema import AgentAction, AgentFinish, LLMResult + from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.streamlit.mutable_expander import MutableExpander -from langchain.schema import AgentAction, AgentFinish, LLMResult if TYPE_CHECKING: from streamlit.delta_generator import DeltaGenerator diff --git a/libs/langchain/langchain/callbacks/tracers/__init__.py b/libs/langchain/langchain/callbacks/tracers/__init__.py index e33002bb306..7ff9ca7f3cc 100644 --- a/libs/langchain/langchain/callbacks/tracers/__init__.py +++ b/libs/langchain/langchain/callbacks/tracers/__init__.py @@ -1,13 +1,14 @@ """Tracers that record execution of LangChain runs.""" -from langchain.callbacks.tracers.wandb import WandbTracer -from langchain.schema.callbacks.tracers.langchain import LangChainTracer -from langchain.schema.callbacks.tracers.langchain_v1 import LangChainTracerV1 -from langchain.schema.callbacks.tracers.stdout import ( +from langchain_core.callbacks.tracers.langchain import LangChainTracer +from langchain_core.callbacks.tracers.langchain_v1 import LangChainTracerV1 +from langchain_core.callbacks.tracers.stdout import ( ConsoleCallbackHandler, FunctionCallbackHandler, ) +from langchain.callbacks.tracers.wandb import WandbTracer + __all__ = [ "LangChainTracer", "LangChainTracerV1", diff --git a/libs/langchain/langchain/callbacks/tracers/base.py b/libs/langchain/langchain/callbacks/tracers/base.py index 34946ff374b..e2628a9ba18 100644 --- a/libs/langchain/langchain/callbacks/tracers/base.py +++ b/libs/langchain/langchain/callbacks/tracers/base.py @@ -1,5 +1,5 @@ """Base interfaces for tracing runs.""" -from langchain.schema.callbacks.tracers.base import BaseTracer, TracerException +from langchain_core.callbacks.tracers.base import BaseTracer, TracerException __all__ = ["BaseTracer", "TracerException"] diff --git a/libs/langchain/langchain/callbacks/tracers/evaluation.py b/libs/langchain/langchain/callbacks/tracers/evaluation.py index 8384ea6557d..0ee0dbe1b49 100644 --- a/libs/langchain/langchain/callbacks/tracers/evaluation.py +++ b/libs/langchain/langchain/callbacks/tracers/evaluation.py @@ -1,5 +1,5 @@ """A tracer that runs evaluators over completed runs.""" -from langchain.schema.callbacks.tracers.evaluation import ( +from langchain_core.callbacks.tracers.evaluation import ( EvaluatorCallbackHandler, wait_for_all_evaluators, ) diff --git a/libs/langchain/langchain/callbacks/tracers/langchain.py b/libs/langchain/langchain/callbacks/tracers/langchain.py index 1cfe3ffc0a0..031b1244f54 100644 --- a/libs/langchain/langchain/callbacks/tracers/langchain.py +++ b/libs/langchain/langchain/callbacks/tracers/langchain.py @@ -1,6 +1,6 @@ """A Tracer implementation that records to LangChain endpoint.""" -from langchain.schema.callbacks.tracers.langchain import ( +from langchain_core.callbacks.tracers.langchain import ( LangChainTracer, wait_for_all_tracers, ) diff --git a/libs/langchain/langchain/callbacks/tracers/langchain_v1.py 
b/libs/langchain/langchain/callbacks/tracers/langchain_v1.py index 056b5c4786a..7c426f3945e 100644 --- a/libs/langchain/langchain/callbacks/tracers/langchain_v1.py +++ b/libs/langchain/langchain/callbacks/tracers/langchain_v1.py @@ -1,3 +1,3 @@ -from langchain.schema.callbacks.tracers.langchain_v1 import LangChainTracerV1 +from langchain_core.callbacks.tracers.langchain_v1 import LangChainTracerV1 __all__ = ["LangChainTracerV1"] diff --git a/libs/langchain/langchain/callbacks/tracers/log_stream.py b/libs/langchain/langchain/callbacks/tracers/log_stream.py index 6630dd6e53f..0878fa575e6 100644 --- a/libs/langchain/langchain/callbacks/tracers/log_stream.py +++ b/libs/langchain/langchain/callbacks/tracers/log_stream.py @@ -1,4 +1,4 @@ -from langchain.schema.callbacks.tracers.log_stream import ( +from langchain_core.callbacks.tracers.log_stream import ( LogEntry, LogStreamCallbackHandler, RunLog, diff --git a/libs/langchain/langchain/callbacks/tracers/root_listeners.py b/libs/langchain/langchain/callbacks/tracers/root_listeners.py index 2eceb6db7cb..f57b31c938d 100644 --- a/libs/langchain/langchain/callbacks/tracers/root_listeners.py +++ b/libs/langchain/langchain/callbacks/tracers/root_listeners.py @@ -1,3 +1,3 @@ -from langchain.schema.callbacks.tracers.root_listeners import RootListenersTracer +from langchain_core.callbacks.tracers.root_listeners import RootListenersTracer __all__ = ["RootListenersTracer"] diff --git a/libs/langchain/langchain/callbacks/tracers/run_collector.py b/libs/langchain/langchain/callbacks/tracers/run_collector.py index da4b7ee8d8e..1e872946631 100644 --- a/libs/langchain/langchain/callbacks/tracers/run_collector.py +++ b/libs/langchain/langchain/callbacks/tracers/run_collector.py @@ -1,3 +1,3 @@ -from langchain.schema.callbacks.tracers.run_collector import RunCollectorCallbackHandler +from langchain_core.callbacks.tracers.run_collector import RunCollectorCallbackHandler __all__ = ["RunCollectorCallbackHandler"] diff --git a/libs/langchain/langchain/callbacks/tracers/schemas.py b/libs/langchain/langchain/callbacks/tracers/schemas.py index b4445454891..824e7576895 100644 --- a/libs/langchain/langchain/callbacks/tracers/schemas.py +++ b/libs/langchain/langchain/callbacks/tracers/schemas.py @@ -1,4 +1,4 @@ -from langchain.schema.callbacks.tracers.schemas import ( +from langchain_core.callbacks.tracers.schemas import ( BaseRun, ChainRun, LLMRun, diff --git a/libs/langchain/langchain/callbacks/tracers/stdout.py b/libs/langchain/langchain/callbacks/tracers/stdout.py index 12e8a187da2..6294ada57c3 100644 --- a/libs/langchain/langchain/callbacks/tracers/stdout.py +++ b/libs/langchain/langchain/callbacks/tracers/stdout.py @@ -1,4 +1,4 @@ -from langchain.schema.callbacks.tracers.stdout import ( +from langchain_core.callbacks.tracers.stdout import ( ConsoleCallbackHandler, FunctionCallbackHandler, ) diff --git a/libs/langchain/langchain/callbacks/trubrics_callback.py b/libs/langchain/langchain/callbacks/trubrics_callback.py index 168793aeed5..9502e690fd2 100644 --- a/libs/langchain/langchain/callbacks/trubrics_callback.py +++ b/libs/langchain/langchain/callbacks/trubrics_callback.py @@ -2,9 +2,8 @@ import os from typing import Any, Dict, List, Optional from uuid import UUID -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import LLMResult -from langchain.schema.messages import ( +from langchain_core.schema import LLMResult +from langchain_core.schema.messages import ( AIMessage, BaseMessage, ChatMessage, @@ -13,6 +12,8 @@ from 
langchain.schema.messages import ( SystemMessage, ) +from langchain.callbacks.base import BaseCallbackHandler + def _convert_message_to_dict(message: BaseMessage) -> dict: message_dict: Dict[str, Any] diff --git a/libs/langchain/langchain/callbacks/wandb_callback.py b/libs/langchain/langchain/callbacks/wandb_callback.py index 18102b34bdd..09baec4ed0f 100644 --- a/libs/langchain/langchain/callbacks/wandb_callback.py +++ b/libs/langchain/langchain/callbacks/wandb_callback.py @@ -4,6 +4,8 @@ from copy import deepcopy from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Union +from langchain_core.schema import AgentAction, AgentFinish, LLMResult + from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.utils import ( BaseMetadataCallbackHandler, @@ -13,7 +15,6 @@ from langchain.callbacks.utils import ( import_spacy, import_textstat, ) -from langchain.schema import AgentAction, AgentFinish, LLMResult def import_wandb() -> Any: diff --git a/libs/langchain/langchain/chains/api/base.py b/libs/langchain/langchain/chains/api/base.py index b8d0d178711..82d129e0fb2 100644 --- a/libs/langchain/langchain/chains/api/base.py +++ b/libs/langchain/langchain/chains/api/base.py @@ -4,6 +4,10 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, Sequence, Tuple from urllib.parse import urlparse +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -11,9 +15,6 @@ from langchain.callbacks.manager import ( from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel from langchain.utilities.requests import TextRequestsWrapper diff --git a/libs/langchain/langchain/chains/api/openapi/chain.py b/libs/langchain/langchain/chains/api/openapi/chain.py index 6db93aefc39..9cf983f978e 100644 --- a/libs/langchain/langchain/chains/api/openapi/chain.py +++ b/libs/langchain/langchain/chains/api/openapi/chain.py @@ -4,6 +4,8 @@ from __future__ import annotations import json from typing import Any, Dict, List, NamedTuple, Optional, cast +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema.language_model import BaseLanguageModel from requests import Response from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks @@ -11,8 +13,6 @@ from langchain.chains.api.openapi.requests_chain import APIRequesterChain from langchain.chains.api.openapi.response_chain import APIResponderChain from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema.language_model import BaseLanguageModel from langchain.tools.openapi.utils.api_models import APIOperation from langchain.utilities.requests import Requests diff --git a/libs/langchain/langchain/chains/api/openapi/requests_chain.py b/libs/langchain/langchain/chains/api/openapi/requests_chain.py index 002cdbe5799..4e85345e4ea 100644 --- a/libs/langchain/langchain/chains/api/openapi/requests_chain.py +++ b/libs/langchain/langchain/chains/api/openapi/requests_chain.py @@ -4,11 
+4,12 @@ import json import re from typing import Any +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.schema import BaseOutputParser +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE from langchain.chains.llm import LLMChain -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import BaseOutputParser -from langchain.schema.language_model import BaseLanguageModel class APIRequesterOutputParser(BaseOutputParser): diff --git a/libs/langchain/langchain/chains/api/openapi/response_chain.py b/libs/langchain/langchain/chains/api/openapi/response_chain.py index 18b2617e578..8699d4606db 100644 --- a/libs/langchain/langchain/chains/api/openapi/response_chain.py +++ b/libs/langchain/langchain/chains/api/openapi/response_chain.py @@ -4,11 +4,12 @@ import json import re from typing import Any +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.schema import BaseOutputParser +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE from langchain.chains.llm import LLMChain -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import BaseOutputParser -from langchain.schema.language_model import BaseLanguageModel class APIResponderOutputParser(BaseOutputParser): diff --git a/libs/langchain/langchain/chains/api/prompt.py b/libs/langchain/langchain/chains/api/prompt.py index 020ac8d1b4c..0ffc389ad3d 100644 --- a/libs/langchain/langchain/chains/api/prompt.py +++ b/libs/langchain/langchain/chains/api/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate API_URL_PROMPT_TEMPLATE = """You are given the below API Documentation: {api_docs} diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 43af5d31a16..52e5c20c8ff 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -9,6 +9,16 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Type, Union import yaml +from langchain_core.load.dump import dumpd +from langchain_core.pydantic_v1 import ( + BaseModel, + Field, + create_model, + root_validator, + validator, +) +from langchain_core.runnables import RunnableConfig, RunnableSerializable +from langchain_core.schema import RUN_KEY, BaseMemory, RunInfo from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import ( @@ -18,16 +28,6 @@ from langchain.callbacks.manager import ( CallbackManagerForChainRun, Callbacks, ) -from langchain.load.dump import dumpd -from langchain.pydantic_v1 import ( - BaseModel, - Field, - create_model, - root_validator, - validator, -) -from langchain.schema import RUN_KEY, BaseMemory, RunInfo -from langchain.schema.runnable import RunnableConfig, RunnableSerializable logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chains/chat_vector_db/prompts.py b/libs/langchain/langchain/chains/chat_vector_db/prompts.py index b2a2df09e3f..19f7a210386 100644 --- a/libs/langchain/langchain/chains/chat_vector_db/prompts.py +++ b/libs/langchain/langchain/chains/chat_vector_db/prompts.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate _template = """Given the following conversation and a 
follow up question, rephrase the follow up question to be a standalone question. diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index ea28e99adda..23192de2abc 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -3,14 +3,15 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, Type +from langchain_core.pydantic_v1 import BaseModel, Field, create_model +from langchain_core.runnables.config import RunnableConfig + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) from langchain.chains.base import Chain from langchain.docstore.document import Document -from langchain.pydantic_v1 import BaseModel, Field, create_model -from langchain.schema.runnable.config import RunnableConfig from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter diff --git a/libs/langchain/langchain/chains/combine_documents/map_reduce.py b/libs/langchain/langchain/chains/combine_documents/map_reduce.py index fcbb721e1b6..02672804e5b 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/map_reduce.py @@ -4,13 +4,14 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, Tuple, Type +from langchain_core.pydantic_v1 import BaseModel, Extra, create_model, root_validator +from langchain_core.runnables.config import RunnableConfig + from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.reduce import ReduceDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document -from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator -from langchain.schema.runnable.config import RunnableConfig class MapReduceDocumentsChain(BaseCombineDocumentsChain): @@ -31,7 +32,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): ReduceDocumentsChain, MapReduceDocumentsChain, ) - from langchain.prompts import PromptTemplate + from langchain_core.prompts import PromptTemplate from langchain.llms import OpenAI # This controls how each document will be formatted. 
Specifically, diff --git a/libs/langchain/langchain/chains/combine_documents/map_rerank.py b/libs/langchain/langchain/chains/combine_documents/map_rerank.py index 717222ed950..f051bad2e32 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_rerank.py +++ b/libs/langchain/langchain/chains/combine_documents/map_rerank.py @@ -4,13 +4,14 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast +from langchain_core.pydantic_v1 import BaseModel, Extra, create_model, root_validator +from langchain_core.runnables.config import RunnableConfig + from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document from langchain.output_parsers.regex import RegexParser -from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator -from langchain.schema.runnable.config import RunnableConfig class MapRerankDocumentsChain(BaseCombineDocumentsChain): @@ -24,7 +25,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): .. code-block:: python from langchain.chains import StuffDocumentsChain, LLMChain - from langchain.prompts import PromptTemplate + from langchain_core.prompts import PromptTemplate from langchain.llms import OpenAI from langchain.output_parsers.regex import RegexParser diff --git a/libs/langchain/langchain/chains/combine_documents/reduce.py b/libs/langchain/langchain/chains/combine_documents/reduce.py index 95081704594..ea402033f56 100644 --- a/libs/langchain/langchain/chains/combine_documents/reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/reduce.py @@ -4,10 +4,11 @@ from __future__ import annotations from typing import Any, Callable, List, Optional, Protocol, Tuple +from langchain_core.pydantic_v1 import Extra + from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.docstore.document import Document -from langchain.pydantic_v1 import Extra class CombineDocsProtocol(Protocol): @@ -144,7 +145,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): from langchain.chains import ( StuffDocumentsChain, LLMChain, ReduceDocumentsChain ) - from langchain.prompts import PromptTemplate + from langchain_core.prompts import PromptTemplate from langchain.llms import OpenAI # This controls how each document will be formatted. 
Specifically, diff --git a/libs/langchain/langchain/chains/combine_documents/refine.py b/libs/langchain/langchain/chains/combine_documents/refine.py index 4cdfa2ad77d..1c3afbde5c4 100644 --- a/libs/langchain/langchain/chains/combine_documents/refine.py +++ b/libs/langchain/langchain/chains/combine_documents/refine.py @@ -4,15 +4,16 @@ from __future__ import annotations from typing import Any, Dict, List, Tuple +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.schema import BasePromptTemplate, format_document + from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import ( BaseCombineDocumentsChain, ) from langchain.chains.llm import LLMChain from langchain.docstore.document import Document -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import Extra, Field, root_validator -from langchain.schema import BasePromptTemplate, format_document def _get_default_document_prompt() -> PromptTemplate: @@ -35,7 +36,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain): .. code-block:: python from langchain.chains import RefineDocumentsChain, LLMChain - from langchain.prompts import PromptTemplate + from langchain_core.prompts import PromptTemplate from langchain.llms import OpenAI # This controls how each document will be formatted. Specifically, diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index 063efe6244d..8c028a9df19 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -2,15 +2,16 @@ from typing import Any, Dict, List, Optional, Tuple +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.schema import BasePromptTemplate, format_document + from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import ( BaseCombineDocumentsChain, ) from langchain.chains.llm import LLMChain from langchain.docstore.document import Document -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import Extra, Field, root_validator -from langchain.schema import BasePromptTemplate, format_document def _get_default_document_prompt() -> PromptTemplate: @@ -30,7 +31,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): .. code-block:: python from langchain.chains import StuffDocumentsChain, LLMChain - from langchain.prompts import PromptTemplate + from langchain_core.prompts import PromptTemplate from langchain.llms import OpenAI # This controls how each document will be formatted. 
Specifically, diff --git a/libs/langchain/langchain/chains/constitutional_ai/base.py b/libs/langchain/langchain/chains/constitutional_ai/base.py index 7bdd2809226..8c2b26ecc60 100644 --- a/libs/langchain/langchain/chains/constitutional_ai/base.py +++ b/libs/langchain/langchain/chains/constitutional_ai/base.py @@ -1,14 +1,15 @@ """Chain for applying constitutional principles to the outputs of another chain.""" from typing import Any, Dict, List, Optional +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.constitutional_ai.principles import PRINCIPLES from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT from langchain.chains.llm import LLMChain -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel class ConstitutionalChain(Chain): diff --git a/libs/langchain/langchain/chains/constitutional_ai/models.py b/libs/langchain/langchain/chains/constitutional_ai/models.py index 74e6a562530..97ea1823751 100644 --- a/libs/langchain/langchain/chains/constitutional_ai/models.py +++ b/libs/langchain/langchain/chains/constitutional_ai/models.py @@ -1,5 +1,5 @@ """Models for the Constitutional AI chain.""" -from langchain.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel class ConstitutionalPrinciple(BaseModel): diff --git a/libs/langchain/langchain/chains/constitutional_ai/prompts.py b/libs/langchain/langchain/chains/constitutional_ai/prompts.py index 54501dc720d..5e9c933b566 100644 --- a/libs/langchain/langchain/chains/constitutional_ai/prompts.py +++ b/libs/langchain/langchain/chains/constitutional_ai/prompts.py @@ -1,8 +1,8 @@ # flake8: noqa from copy import deepcopy -from langchain.prompts.few_shot import FewShotPromptTemplate -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.few_shot import FewShotPromptTemplate +from langchain_core.prompts.prompt import PromptTemplate critique_example = PromptTemplate( template="""Human: {input_prompt} diff --git a/libs/langchain/langchain/chains/conversation/base.py b/libs/langchain/langchain/chains/conversation/base.py index 06e5481c306..f9575d9133a 100644 --- a/libs/langchain/langchain/chains/conversation/base.py +++ b/libs/langchain/langchain/chains/conversation/base.py @@ -1,11 +1,12 @@ """Chain that carries on a conversation and calls an LLM.""" from typing import Dict, List +from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.schema import BaseMemory, BasePromptTemplate + from langchain.chains.conversation.prompt import PROMPT from langchain.chains.llm import LLMChain from langchain.memory.buffer import ConversationBufferMemory -from langchain.pydantic_v1 import Extra, Field, root_validator -from langchain.schema import BaseMemory, BasePromptTemplate class ConversationChain(LLMChain): diff --git a/libs/langchain/langchain/chains/conversation/prompt.py b/libs/langchain/langchain/chains/conversation/prompt.py index 3209a9da97d..04dc9c2f738 100644 --- a/libs/langchain/langchain/chains/conversation/prompt.py +++ b/libs/langchain/langchain/chains/conversation/prompt.py @@ -6,7 +6,7 @@ from langchain.memory.prompt import ( KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT, SUMMARY_PROMPT, ) -from 
langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate DEFAULT_TEMPLATE = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index 6812fb3907e..3624c874733 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -7,6 +7,13 @@ from abc import abstractmethod from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.runnables.config import RunnableConfig +from langchain_core.schema import BasePromptTemplate, BaseRetriever, Document +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import BaseMessage +from langchain_core.schema.vectorstore import VectorStore + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -18,12 +25,6 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT from langchain.chains.llm import LLMChain from langchain.chains.question_answering import load_qa_chain -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.schema import BasePromptTemplate, BaseRetriever, Document -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import BaseMessage -from langchain.schema.runnable.config import RunnableConfig -from langchain.schema.vectorstore import VectorStore # Depending on the memory type and configuration, the chat history format may differ. # This needs to be consolidated. @@ -247,7 +248,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain): from langchain.chains import ( StuffDocumentsChain, LLMChain, ConversationalRetrievalChain ) - from langchain.prompts import PromptTemplate + from langchain_core.prompts import PromptTemplate from langchain.llms import OpenAI combine_docs_chain = StuffDocumentsChain(...) diff --git a/libs/langchain/langchain/chains/conversational_retrieval/prompts.py b/libs/langchain/langchain/chains/conversational_retrieval/prompts.py index 537f9378535..f0e5aae0139 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/prompts.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/prompts.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. 
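The hunks above only change where PromptTemplate is imported from; the class itself behaves the same under the new path. A minimal sketch of the pattern, with illustrative template text rather than the actual condense-question template in this patch:

.. code-block:: python

    from langchain_core.prompts.prompt import PromptTemplate

    # The prompts touched above are plain string templates with named variables,
    # formatted before being handed to the LLM. Only the import path moved.
    condense = PromptTemplate.from_template(
        "Given the conversation so far, rephrase {question} as a standalone question."
    )
    print(condense.format(question="what about the second point?"))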
diff --git a/libs/langchain/langchain/chains/elasticsearch_database/base.py b/libs/langchain/langchain/chains/elasticsearch_database/base.py index 5ee33c9fb99..49ee74cb2d9 100644 --- a/libs/langchain/langchain/chains/elasticsearch_database/base.py +++ b/libs/langchain/langchain/chains/elasticsearch_database/base.py @@ -3,14 +3,15 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema import BaseLLMOutputParser, BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT from langchain.chains.llm import LLMChain from langchain.output_parsers.json import SimpleJsonOutputParser -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema import BaseLLMOutputParser, BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel if TYPE_CHECKING: from elasticsearch import Elasticsearch diff --git a/libs/langchain/langchain/chains/elasticsearch_database/prompts.py b/libs/langchain/langchain/chains/elasticsearch_database/prompts.py index 9d9b6b00fec..196d97a9db9 100644 --- a/libs/langchain/langchain/chains/elasticsearch_database/prompts.py +++ b/libs/langchain/langchain/chains/elasticsearch_database/prompts.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate PROMPT_SUFFIX = """Only use the following Elasticsearch indices: {indices_info} diff --git a/libs/langchain/langchain/chains/example_generator.py b/libs/langchain/langchain/chains/example_generator.py index c01cba46671..da84ad9b493 100644 --- a/libs/langchain/langchain/chains/example_generator.py +++ b/libs/langchain/langchain/chains/example_generator.py @@ -1,9 +1,10 @@ from typing import List +from langchain_core.prompts.few_shot import FewShotPromptTemplate +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chains.llm import LLMChain -from langchain.prompts.few_shot import FewShotPromptTemplate -from langchain.prompts.prompt import PromptTemplate -from langchain.schema.language_model import BaseLanguageModel TEST_GEN_TEMPLATE_SUFFIX = "Add another example." 
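example_generator.py above now pulls FewShotPromptTemplate and PromptTemplate from their langchain_core locations. A small sketch of those two classes together, using made-up example data rather than anything from this patch:

.. code-block:: python

    from langchain_core.prompts.few_shot import FewShotPromptTemplate
    from langchain_core.prompts.prompt import PromptTemplate

    # How each stored example is rendered inside the few-shot prompt.
    example_prompt = PromptTemplate.from_template("Q: {question}\nA: {answer}")

    prompt = FewShotPromptTemplate(
        examples=[{"question": "2 + 2?", "answer": "4"}],
        example_prompt=example_prompt,
        suffix="Q: {question}\nA:",
        input_variables=["question"],
    )
    print(prompt.format(question="3 + 3?"))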
diff --git a/libs/langchain/langchain/chains/flare/base.py b/libs/langchain/langchain/chains/flare/base.py index 198da0fcaf5..e11b3e0648a 100644 --- a/libs/langchain/langchain/chains/flare/base.py +++ b/libs/langchain/langchain/chains/flare/base.py @@ -5,6 +5,9 @@ from abc import abstractmethod from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate, BaseRetriever, Generation +from langchain_core.schema.language_model import BaseLanguageModel from langchain.callbacks.manager import ( CallbackManagerForChainRun, @@ -17,9 +20,6 @@ from langchain.chains.flare.prompts import ( ) from langchain.chains.llm import LLMChain from langchain.llms.openai import OpenAI -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate, BaseRetriever, Generation -from langchain.schema.language_model import BaseLanguageModel class _ResponseChain(LLMChain): diff --git a/libs/langchain/langchain/chains/flare/prompts.py b/libs/langchain/langchain/chains/flare/prompts.py index c6ef2e0383b..b5e63d2a97b 100644 --- a/libs/langchain/langchain/chains/flare/prompts.py +++ b/libs/langchain/langchain/chains/flare/prompts.py @@ -1,7 +1,7 @@ from typing import Tuple -from langchain.prompts import PromptTemplate -from langchain.schema import BaseOutputParser +from langchain_core.prompts import PromptTemplate +from langchain_core.schema import BaseOutputParser class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]): diff --git a/libs/langchain/langchain/chains/graph_qa/arangodb.py b/libs/langchain/langchain/chains/graph_qa/arangodb.py index 48066c8ee1f..36aa44c1869 100644 --- a/libs/langchain/langchain/chains/graph_qa/arangodb.py +++ b/libs/langchain/langchain/chains/graph_qa/arangodb.py @@ -4,6 +4,9 @@ from __future__ import annotations import re from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate + from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain @@ -14,8 +17,6 @@ from langchain.chains.graph_qa.prompts import ( ) from langchain.chains.llm import LLMChain from langchain.graphs.arangodb_graph import ArangoGraph -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate class ArangoGraphQAChain(Chain): diff --git a/libs/langchain/langchain/chains/graph_qa/base.py b/libs/langchain/langchain/chains/graph_qa/base.py index 8543c9a775b..f38902d1ab3 100644 --- a/libs/langchain/langchain/chains/graph_qa/base.py +++ b/libs/langchain/langchain/chains/graph_qa/base.py @@ -3,14 +3,15 @@ from __future__ import annotations from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, GRAPH_QA_PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel class GraphQAChain(Chain): diff --git 
a/libs/langchain/langchain/chains/graph_qa/cypher.py b/libs/langchain/langchain/chains/graph_qa/cypher.py index 3531d517f90..43a24955a18 100644 --- a/libs/langchain/langchain/chains/graph_qa/cypher.py +++ b/libs/langchain/langchain/chains/graph_qa/cypher.py @@ -4,15 +4,16 @@ from __future__ import annotations import re from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.graph_store import GraphStore -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel INTERMEDIATE_STEPS_KEY = "intermediate_steps" diff --git a/libs/langchain/langchain/chains/graph_qa/falkordb.py b/libs/langchain/langchain/chains/graph_qa/falkordb.py index f973b876b26..7bf0311848d 100644 --- a/libs/langchain/langchain/chains/graph_qa/falkordb.py +++ b/libs/langchain/langchain/chains/graph_qa/falkordb.py @@ -4,14 +4,15 @@ from __future__ import annotations import re from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate + from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT from langchain.chains.llm import LLMChain from langchain.graphs import FalkorDBGraph -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate INTERMEDIATE_STEPS_KEY = "intermediate_steps" diff --git a/libs/langchain/langchain/chains/graph_qa/hugegraph.py b/libs/langchain/langchain/chains/graph_qa/hugegraph.py index 3618dedebb1..add6a0c3ad9 100644 --- a/libs/langchain/langchain/chains/graph_qa/hugegraph.py +++ b/libs/langchain/langchain/chains/graph_qa/hugegraph.py @@ -3,6 +3,10 @@ from __future__ import annotations from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import ( @@ -11,9 +15,6 @@ from langchain.chains.graph_qa.prompts import ( ) from langchain.chains.llm import LLMChain from langchain.graphs.hugegraph import HugeGraph -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel class HugeGraphQAChain(Chain): diff --git a/libs/langchain/langchain/chains/graph_qa/kuzu.py b/libs/langchain/langchain/chains/graph_qa/kuzu.py index 2b63edae6b3..a04eb9ef67a 100644 --- a/libs/langchain/langchain/chains/graph_qa/kuzu.py +++ b/libs/langchain/langchain/chains/graph_qa/kuzu.py @@ -3,14 +3,15 @@ from __future__ import annotations from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import 
BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.kuzu_graph import KuzuGraph -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel class KuzuQAChain(Chain): diff --git a/libs/langchain/langchain/chains/graph_qa/nebulagraph.py b/libs/langchain/langchain/chains/graph_qa/nebulagraph.py index 9bd5f5ec430..7bb966ca943 100644 --- a/libs/langchain/langchain/chains/graph_qa/nebulagraph.py +++ b/libs/langchain/langchain/chains/graph_qa/nebulagraph.py @@ -3,14 +3,15 @@ from __future__ import annotations from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.nebula_graph import NebulaGraph -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel class NebulaGraphQAChain(Chain): diff --git a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py index ec55a93f0cb..0627ced5808 100644 --- a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py +++ b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py @@ -3,6 +3,9 @@ from __future__ import annotations import re from typing import Any, Dict, List, Optional +from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.pydantic_v1 import Field + from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain @@ -14,8 +17,6 @@ from langchain.chains.graph_qa.prompts import ( from langchain.chains.llm import LLMChain from langchain.chains.prompt_selector import ConditionalPromptSelector from langchain.graphs import NeptuneGraph -from langchain.prompts.base import BasePromptTemplate -from langchain.pydantic_v1 import Field INTERMEDIATE_STEPS_KEY = "intermediate_steps" diff --git a/libs/langchain/langchain/chains/graph_qa/prompts.py b/libs/langchain/langchain/chains/graph_qa/prompts.py index 193626376b6..6ca1e70266c 100644 --- a/libs/langchain/langchain/chains/graph_qa/prompts.py +++ b/libs/langchain/langchain/chains/graph_qa/prompts.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate _DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """Extract all entities from the following text. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places. 
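The graph QA chains above now take Field and the base schema classes from langchain_core. A sketch of the langchain_core.pydantic_v1 re-export in use; GraphQASettings is a hypothetical model invented for illustration, not a class from this patch:

.. code-block:: python

    from langchain_core.pydantic_v1 import BaseModel, Field

    # Hypothetical settings model: it only shows that BaseModel and Field
    # resolve from the new langchain_core.pydantic_v1 path.
    class GraphQASettings(BaseModel):
        top_k: int = Field(default=10, description="How many query results to keep")
        return_intermediate_steps: bool = False

    print(GraphQASettings().top_k)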
diff --git a/libs/langchain/langchain/chains/graph_qa/sparql.py b/libs/langchain/langchain/chains/graph_qa/sparql.py index 98dee88f3b8..7f6336537ec 100644 --- a/libs/langchain/langchain/chains/graph_qa/sparql.py +++ b/libs/langchain/langchain/chains/graph_qa/sparql.py @@ -5,6 +5,10 @@ from __future__ import annotations from typing import Any, Dict, List, Optional +from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.pydantic_v1 import Field +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import ( @@ -15,9 +19,6 @@ from langchain.chains.graph_qa.prompts import ( ) from langchain.chains.llm import LLMChain from langchain.graphs.rdf_graph import RdfGraph -from langchain.prompts.base import BasePromptTemplate -from langchain.pydantic_v1 import Field -from langchain.schema.language_model import BaseLanguageModel class GraphSparqlQAChain(Chain): diff --git a/libs/langchain/langchain/chains/hyde/base.py b/libs/langchain/langchain/chains/hyde/base.py index 2c4794e001a..07b4e4de7da 100644 --- a/libs/langchain/langchain/chains/hyde/base.py +++ b/libs/langchain/langchain/chains/hyde/base.py @@ -7,14 +7,14 @@ from __future__ import annotations from typing import Any, Dict, List, Optional import numpy as np +from langchain_core.pydantic_v1 import Extra +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.language_model import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.chains.llm import LLMChain -from langchain.pydantic_v1 import Extra -from langchain.schema.embeddings import Embeddings -from langchain.schema.language_model import BaseLanguageModel class HypotheticalDocumentEmbedder(Chain, Embeddings): diff --git a/libs/langchain/langchain/chains/hyde/prompts.py b/libs/langchain/langchain/chains/hyde/prompts.py index 746cce3a1db..36c1cc0614b 100644 --- a/libs/langchain/langchain/chains/hyde/prompts.py +++ b/libs/langchain/langchain/chains/hyde/prompts.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate web_search_template = """Please write a passage to answer the question Question: {QUESTION} diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py index 33555f4d86a..aec49e29703 100644 --- a/libs/langchain/langchain/chains/llm.py +++ b/libs/langchain/langchain/chains/llm.py @@ -4,18 +4,17 @@ from __future__ import annotations import warnings from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast -from langchain.callbacks.manager import ( - AsyncCallbackManager, - AsyncCallbackManagerForChainRun, - CallbackManager, - CallbackManagerForChainRun, - Callbacks, +from langchain_core.load.dump import dumpd +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.runnables import ( + Runnable, + RunnableBinding, + RunnableBranch, + RunnableWithFallbacks, ) -from langchain.chains.base import Chain -from langchain.load.dump import dumpd -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import Extra, Field -from langchain.schema import ( +from langchain_core.runnables.configurable import 
DynamicRunnable +from langchain_core.schema import ( BaseLLMOutputParser, BaseMessage, BasePromptTemplate, @@ -25,18 +24,20 @@ from langchain.schema import ( PromptValue, StrOutputParser, ) -from langchain.schema.language_model import ( +from langchain_core.schema.language_model import ( BaseLanguageModel, LanguageModelInput, ) -from langchain.schema.runnable import ( - Runnable, - RunnableBinding, - RunnableBranch, - RunnableWithFallbacks, +from langchain_core.utils.input import get_colored_text + +from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForChainRun, + CallbackManager, + CallbackManagerForChainRun, + Callbacks, ) -from langchain.schema.runnable.configurable import DynamicRunnable -from langchain.utils.input import get_colored_text +from langchain.chains.base import Chain class LLMChain(Chain): @@ -47,7 +48,7 @@ class LLMChain(Chain): from langchain.chains import LLMChain from langchain.llms import OpenAI - from langchain.prompts import PromptTemplate + from langchain_core.prompts import PromptTemplate prompt_template = "Tell me a {adjective} joke" prompt = PromptTemplate( input_variables=["adjective"], template=prompt_template diff --git a/libs/langchain/langchain/chains/llm_checker/base.py b/libs/langchain/langchain/chains/llm_checker/base.py index a9ff8178dac..10d59348e51 100644 --- a/libs/langchain/langchain/chains/llm_checker/base.py +++ b/libs/langchain/langchain/chains/llm_checker/base.py @@ -4,6 +4,10 @@ from __future__ import annotations import warnings from typing import Any, Dict, List, Optional +from langchain_core.prompts import PromptTemplate +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -14,9 +18,6 @@ from langchain.chains.llm_checker.prompt import ( REVISED_ANSWER_PROMPT, ) from langchain.chains.sequential import SequentialChain -from langchain.prompts import PromptTemplate -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema.language_model import BaseLanguageModel def _load_question_to_checked_assertions_chain( diff --git a/libs/langchain/langchain/chains/llm_checker/prompt.py b/libs/langchain/langchain/chains/llm_checker/prompt.py index 73c883d0c20..8eb5fdaf229 100644 --- a/libs/langchain/langchain/chains/llm_checker/prompt.py +++ b/libs/langchain/langchain/chains/llm_checker/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate _CREATE_DRAFT_ANSWER_TEMPLATE = """{question}\n\n""" CREATE_DRAFT_ANSWER_PROMPT = PromptTemplate( diff --git a/libs/langchain/langchain/chains/llm_math/base.py b/libs/langchain/langchain/chains/llm_math/base.py index 58d3d31f773..15a6683cfd7 100644 --- a/libs/langchain/langchain/chains/llm_math/base.py +++ b/libs/langchain/langchain/chains/llm_math/base.py @@ -6,6 +6,10 @@ import re import warnings from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -13,9 +17,6 @@ from langchain.callbacks.manager import ( from langchain.chains.base import Chain from 
langchain.chains.llm import LLMChain from langchain.chains.llm_math.prompt import PROMPT -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel class LLMMathChain(Chain): diff --git a/libs/langchain/langchain/chains/llm_math/prompt.py b/libs/langchain/langchain/chains/llm_math/prompt.py index 86595553322..8c0fd9e8bdc 100644 --- a/libs/langchain/langchain/chains/llm_math/prompt.py +++ b/libs/langchain/langchain/chains/llm_math/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate _PROMPT_TEMPLATE = """Translate a math problem into a expression that can be executed using Python's numexpr library. Use the output of running this code to answer the question. diff --git a/libs/langchain/langchain/chains/llm_requests.py b/libs/langchain/langchain/chains/llm_requests.py index 4abe365106e..0dcb2b3cecf 100644 --- a/libs/langchain/langchain/chains/llm_requests.py +++ b/libs/langchain/langchain/chains/llm_requests.py @@ -3,10 +3,11 @@ from __future__ import annotations from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Extra, Field, root_validator + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains import LLMChain from langchain.chains.base import Chain -from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.utilities.requests import TextRequestsWrapper DEFAULT_HEADERS = { diff --git a/libs/langchain/langchain/chains/llm_summarization_checker/base.py b/libs/langchain/langchain/chains/llm_summarization_checker/base.py index 282b4318003..d075faff535 100644 --- a/libs/langchain/langchain/chains/llm_summarization_checker/base.py +++ b/libs/langchain/langchain/chains/llm_summarization_checker/base.py @@ -6,13 +6,14 @@ import warnings from pathlib import Path from typing import Any, Dict, List, Optional +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sequential import SequentialChain -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema.language_model import BaseLanguageModel PROMPTS_DIR = Path(__file__).parent / "prompts" diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py index 414719d49ed..963fb5383bd 100644 --- a/libs/langchain/langchain/chains/loading.py +++ b/libs/langchain/langchain/chains/loading.py @@ -4,6 +4,11 @@ from pathlib import Path from typing import Any, Union import yaml +from langchain_core.prompts.loading import ( + _load_output_parser, + load_prompt, + load_prompt_from_config, +) from langchain.chains import ReduceDocumentsChain from langchain.chains.api.base import APIChain @@ -23,11 +28,6 @@ from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesCha from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA from langchain.llms.loading import load_llm, load_llm_from_config -from langchain.prompts.loading import ( - _load_output_parser, 
- load_prompt, - load_prompt_from_config, -) from langchain.utilities.loading import try_load_from_hub URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/" diff --git a/libs/langchain/langchain/chains/mapreduce.py b/libs/langchain/langchain/chains/mapreduce.py index 0374d12da59..a3fe73319a9 100644 --- a/libs/langchain/langchain/chains/mapreduce.py +++ b/libs/langchain/langchain/chains/mapreduce.py @@ -7,6 +7,10 @@ from __future__ import annotations from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks from langchain.chains import ReduceDocumentsChain from langchain.chains.base import Chain @@ -15,9 +19,6 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document -from langchain.pydantic_v1 import Extra -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel from langchain.text_splitter import TextSplitter diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index 6c31a957f25..c2935f3feb3 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -1,9 +1,10 @@ """Pass input through a moderation endpoint.""" from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import root_validator + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain -from langchain.pydantic_v1 import root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/chains/natbot/base.py b/libs/langchain/langchain/chains/natbot/base.py index e6c334f5fff..d02bd43c5ad 100644 --- a/libs/langchain/langchain/chains/natbot/base.py +++ b/libs/langchain/langchain/chains/natbot/base.py @@ -4,13 +4,14 @@ from __future__ import annotations import warnings from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.natbot.prompt import PROMPT from langchain.llms.openai import OpenAI -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema.language_model import BaseLanguageModel class NatBotChain(Chain): diff --git a/libs/langchain/langchain/chains/natbot/prompt.py b/libs/langchain/langchain/chains/natbot/prompt.py index 3bbda35bab9..82a35f58c7e 100644 --- a/libs/langchain/langchain/chains/natbot/prompt.py +++ b/libs/langchain/langchain/chains/natbot/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate _PROMPT_TEMPLATE = """ You are an agents controlling a browser. 
You are given: diff --git a/libs/langchain/langchain/chains/openai_functions/base.py b/libs/langchain/langchain/chains/openai_functions/base.py index 80c6d95565b..5f2acff3090 100644 --- a/libs/langchain/langchain/chains/openai_functions/base.py +++ b/libs/langchain/langchain/chains/openai_functions/base.py @@ -13,6 +13,15 @@ from typing import ( cast, ) +from langchain_core.prompts import BasePromptTemplate +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import Runnable +from langchain_core.schema import BaseLLMOutputParser +from langchain_core.schema.output_parser import ( + BaseGenerationOutputParser, + BaseOutputParser, +) + from langchain.base_language import BaseLanguageModel from langchain.chains import LLMChain from langchain.output_parsers.openai_functions import ( @@ -20,11 +29,6 @@ from langchain.output_parsers.openai_functions import ( PydanticAttrOutputFunctionsParser, PydanticOutputFunctionsParser, ) -from langchain.prompts import BasePromptTemplate -from langchain.pydantic_v1 import BaseModel -from langchain.schema import BaseLLMOutputParser -from langchain.schema.output_parser import BaseGenerationOutputParser, BaseOutputParser -from langchain.schema.runnable import Runnable from langchain.utils.openai_functions import convert_pydantic_to_openai_function PYTHON_TO_JSON_TYPES = { @@ -236,8 +240,8 @@ def create_openai_fn_runnable( from langchain.chains.openai_functions import create_openai_fn_chain from langchain.chat_models import ChatOpenAI - from langchain.prompts import ChatPromptTemplate - from langchain.pydantic_v1 import BaseModel, Field + from langchain_core.prompts import ChatPromptTemplate + from langchain_core.pydantic_v1 import BaseModel, Field class RecordPerson(BaseModel): @@ -310,8 +314,8 @@ def create_structured_output_runnable( from langchain.chains.openai_functions import create_structured_output_chain from langchain.chat_models import ChatOpenAI - from langchain.prompts import ChatPromptTemplate - from langchain.pydantic_v1 import BaseModel, Field + from langchain_core.prompts import ChatPromptTemplate + from langchain_core.pydantic_v1 import BaseModel, Field class Dog(BaseModel): \"\"\"Identifying information about a dog.\"\"\" @@ -407,9 +411,9 @@ def create_openai_fn_chain( from langchain.chains.openai_functions import create_openai_fn_chain from langchain.chat_models import ChatOpenAI - from langchain.prompts import ChatPromptTemplate + from langchain_core.prompts import ChatPromptTemplate - from langchain.pydantic_v1 import BaseModel, Field + from langchain_core.pydantic_v1 import BaseModel, Field class RecordPerson(BaseModel): @@ -494,9 +498,9 @@ def create_structured_output_chain( from langchain.chains.openai_functions import create_structured_output_chain from langchain.chat_models import ChatOpenAI - from langchain.prompts import ChatPromptTemplate + from langchain_core.prompts import ChatPromptTemplate - from langchain.pydantic_v1 import BaseModel, Field + from langchain_core.pydantic_v1 import BaseModel, Field class Dog(BaseModel): \"\"\"Identifying information about a dog.\"\"\" diff --git a/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py b/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py index 421a7599569..4aac9a98588 100644 --- a/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py +++ b/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py @@ -1,14 +1,15 @@ from typing import Iterator, List +from langchain_core.prompts.chat 
import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import HumanMessage, SystemMessage + from langchain.chains.llm import LLMChain from langchain.chains.openai_functions.utils import get_llm_kwargs from langchain.output_parsers.openai_functions import ( PydanticOutputFunctionsParser, ) -from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import HumanMessage, SystemMessage class FactWithEvidence(BaseModel): diff --git a/libs/langchain/langchain/chains/openai_functions/extraction.py b/libs/langchain/langchain/chains/openai_functions/extraction.py index b3beb9e3688..cf6b1eea290 100644 --- a/libs/langchain/langchain/chains/openai_functions/extraction.py +++ b/libs/langchain/langchain/chains/openai_functions/extraction.py @@ -1,5 +1,10 @@ from typing import Any, List, Optional +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.openai_functions.utils import ( @@ -11,10 +16,6 @@ from langchain.output_parsers.openai_functions import ( JsonKeyOutputFunctionsParser, PydanticAttrOutputFunctionsParser, ) -from langchain.prompts import ChatPromptTemplate -from langchain.pydantic_v1 import BaseModel -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel def _get_extraction_function(entity_schema: dict) -> dict: diff --git a/libs/langchain/langchain/chains/openai_functions/openapi.py b/libs/langchain/langchain/chains/openai_functions/openapi.py index 32a064d2c27..1c23c39ca75 100644 --- a/libs/langchain/langchain/chains/openai_functions/openapi.py +++ b/libs/langchain/langchain/chains/openai_functions/openapi.py @@ -6,6 +6,10 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import requests +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.utils.input import get_colored_text from requests import Response from langchain.callbacks.manager import CallbackManagerForChainRun @@ -14,12 +18,8 @@ from langchain.chains.llm import LLMChain from langchain.chains.sequential import SequentialChain from langchain.chat_models import ChatOpenAI from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser -from langchain.prompts import ChatPromptTemplate -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel from langchain.tools import APIOperation from langchain.utilities.openapi import OpenAPISpec -from langchain.utils.input import get_colored_text if TYPE_CHECKING: from openapi_pydantic import Parameter diff --git a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py index 770b3e9d76b..cf5a0e68466 100644 --- 
a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py +++ b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py @@ -1,17 +1,18 @@ from typing import Any, List, Optional, Type, Union +from langchain_core.prompts import PromptTemplate +from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import BaseLLMOutputParser +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import HumanMessage, SystemMessage + from langchain.chains.llm import LLMChain from langchain.chains.openai_functions.utils import get_llm_kwargs from langchain.output_parsers.openai_functions import ( OutputFunctionsParser, PydanticOutputFunctionsParser, ) -from langchain.prompts import PromptTemplate -from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema import BaseLLMOutputParser -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import HumanMessage, SystemMessage class AnswerWithSources(BaseModel): diff --git a/libs/langchain/langchain/chains/openai_functions/tagging.py b/libs/langchain/langchain/chains/openai_functions/tagging.py index c9498a73085..f5a5a423b27 100644 --- a/libs/langchain/langchain/chains/openai_functions/tagging.py +++ b/libs/langchain/langchain/chains/openai_functions/tagging.py @@ -1,5 +1,8 @@ from typing import Any, Optional +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs @@ -7,8 +10,6 @@ from langchain.output_parsers.openai_functions import ( JsonOutputFunctionsParser, PydanticOutputFunctionsParser, ) -from langchain.prompts import ChatPromptTemplate -from langchain.schema.language_model import BaseLanguageModel def _get_tagging_function(schema: dict) -> dict: diff --git a/libs/langchain/langchain/chains/openai_tools/extraction.py b/libs/langchain/langchain/chains/openai_tools/extraction.py index e5aee76022c..95bb3e1bf68 100644 --- a/libs/langchain/langchain/chains/openai_tools/extraction.py +++ b/libs/langchain/langchain/chains/openai_tools/extraction.py @@ -1,10 +1,11 @@ from typing import List, Type, Union +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import Runnable +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.output_parsers import PydanticToolsParser -from langchain.prompts import ChatPromptTemplate -from langchain.pydantic_v1 import BaseModel -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.runnable import Runnable from langchain.utils.openai_functions import convert_pydantic_to_openai_function _EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned \ diff --git a/libs/langchain/langchain/chains/prompt_selector.py b/libs/langchain/langchain/chains/prompt_selector.py index 7726ed7ebf0..19aea968867 100644 --- a/libs/langchain/langchain/chains/prompt_selector.py +++ b/libs/langchain/langchain/chains/prompt_selector.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod from typing import Callable, List, Tuple +from 
langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chat_models.base import BaseChatModel from langchain.llms.base import BaseLLM -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel class BasePromptSelector(BaseModel, ABC): diff --git a/libs/langchain/langchain/chains/qa_generation/base.py b/libs/langchain/langchain/chains/qa_generation/base.py index c081060086f..2890e85607f 100644 --- a/libs/langchain/langchain/chains/qa_generation/base.py +++ b/libs/langchain/langchain/chains/qa_generation/base.py @@ -3,13 +3,14 @@ from __future__ import annotations import json from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter diff --git a/libs/langchain/langchain/chains/qa_generation/prompt.py b/libs/langchain/langchain/chains/qa_generation/prompt.py index 3919c2a2395..377a49e4c9b 100644 --- a/libs/langchain/langchain/chains/qa_generation/prompt.py +++ b/libs/langchain/langchain/chains/qa_generation/prompt.py @@ -1,11 +1,11 @@ # flake8: noqa from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model -from langchain.prompts.chat import ( +from langchain_core.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate templ1 = """You are a smart assistant designed to help high school teachers come up with reading comprehension questions. Given a piece of text, you must come up with a question and answer pair that can be used to test a student's reading comprehension abilities. 
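qa_generation/prompt.py and the openai_functions modules above now build their chat prompts from langchain_core.prompts.chat. A sketch with illustrative message text; the real templates live in the files being patched:

.. code-block:: python

    from langchain_core.prompts.chat import (
        ChatPromptTemplate,
        HumanMessagePromptTemplate,
        SystemMessagePromptTemplate,
    )

    # Same system/human template structure as the prompts above, with stand-in text.
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(
                "You write reading comprehension questions about a passage."
            ),
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )
    messages = prompt.format_messages(text="Some passage of text...")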
diff --git a/libs/langchain/langchain/chains/qa_with_sources/base.py b/libs/langchain/langchain/chains/qa_with_sources/base.py index 3e3023ce2b2..8997d487f4b 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/base.py +++ b/libs/langchain/langchain/chains/qa_with_sources/base.py @@ -7,6 +7,10 @@ import re from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -24,9 +28,6 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import ( QUESTION_PROMPT, ) from langchain.docstore.document import Document -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel class BaseQAWithSourcesChain(Chain, ABC): diff --git a/libs/langchain/langchain/chains/qa_with_sources/loading.py b/libs/langchain/langchain/chains/qa_with_sources/loading.py index f6239f49603..5c9e12d0feb 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/loading.py +++ b/libs/langchain/langchain/chains/qa_with_sources/loading.py @@ -3,6 +3,9 @@ from __future__ import annotations from typing import Any, Mapping, Optional, Protocol +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.prompt_template import BasePromptTemplate + from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain @@ -18,8 +21,6 @@ from langchain.chains.qa_with_sources import ( from langchain.chains.question_answering.map_rerank_prompt import ( PROMPT as MAP_RERANK_PROMPT, ) -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.prompt_template import BasePromptTemplate class LoadingCallable(Protocol): diff --git a/libs/langchain/langchain/chains/qa_with_sources/map_reduce_prompt.py b/libs/langchain/langchain/chains/qa_with_sources/map_reduce_prompt.py index 8cafe7ecfbf..e0c8545e704 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/map_reduce_prompt.py +++ b/libs/langchain/langchain/chains/qa_with_sources/map_reduce_prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question. Return any relevant text verbatim. 
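The qa_with_sources prompt modules only swap their import path; the template behaviour is unchanged. As a small sketch of the PromptTemplate API those modules now take from langchain_core (the template text below is shortened and illustrative, not the module's actual prompt):

    from langchain_core.prompts import PromptTemplate

    question_prompt = PromptTemplate.from_template(
        "Use the following portion of a long document to see if any of the text "
        "is relevant to answer the question.\n"
        "{context}\n"
        "Question: {question}\n"
        "Relevant text, if any:"
    )

    # Input variables are inferred from the f-string placeholders.
    assert question_prompt.input_variables == ["context", "question"]
    print(question_prompt.format(context="...", question="What changed?"))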
diff --git a/libs/langchain/langchain/chains/qa_with_sources/refine_prompts.py b/libs/langchain/langchain/chains/qa_with_sources/refine_prompts.py index 6920b6bba18..2e13f54153b 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/refine_prompts.py +++ b/libs/langchain/langchain/chains/qa_with_sources/refine_prompts.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate DEFAULT_REFINE_PROMPT_TMPL = ( "The original question is as follows: {question}\n" diff --git a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py index d47c43e51cf..7f6f56cccc9 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py +++ b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py @@ -2,6 +2,9 @@ from typing import Any, Dict, List +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BaseRetriever + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -9,8 +12,6 @@ from langchain.callbacks.manager import ( from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain from langchain.docstore.document import Document -from langchain.pydantic_v1 import Field -from langchain.schema import BaseRetriever class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain): diff --git a/libs/langchain/langchain/chains/qa_with_sources/stuff_prompt.py b/libs/langchain/langchain/chains/qa_with_sources/stuff_prompt.py index b2112fa12b0..82290ee0507 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/stuff_prompt.py +++ b/libs/langchain/langchain/chains/qa_with_sources/stuff_prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES"). If you don't know the answer, just say that you don't know. Don't try to make up an answer. 
diff --git a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py index 8bb432ce88f..1817fb7abc1 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py +++ b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py @@ -3,6 +3,9 @@ import warnings from typing import Any, Dict, List +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema.vectorstore import VectorStore + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -10,8 +13,6 @@ from langchain.callbacks.manager import ( from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain from langchain.docstore.document import Document -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema.vectorstore import VectorStore class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain): diff --git a/libs/langchain/langchain/chains/query_constructor/base.py b/libs/langchain/langchain/chains/query_constructor/base.py index 7c91b3fa059..4f3b69bed5b 100644 --- a/libs/langchain/langchain/chains/query_constructor/base.py +++ b/libs/langchain/langchain/chains/query_constructor/base.py @@ -4,6 +4,15 @@ from __future__ import annotations import json from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast +from langchain_core.prompts.few_shot import FewShotPromptTemplate +from langchain_core.runnables import Runnable +from langchain_core.schema import ( + BaseOutputParser, + BasePromptTemplate, + OutputParserException, +) +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chains.llm import LLMChain from langchain.chains.query_constructor.ir import ( Comparator, @@ -28,10 +37,6 @@ from langchain.chains.query_constructor.prompt import ( ) from langchain.chains.query_constructor.schema import AttributeInfo from langchain.output_parsers.json import parse_and_check_json_markdown -from langchain.prompts.few_shot import FewShotPromptTemplate -from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.runnable import Runnable class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): diff --git a/libs/langchain/langchain/chains/query_constructor/ir.py b/libs/langchain/langchain/chains/query_constructor/ir.py index 9f412aedb29..8c8cfaa4563 100644 --- a/libs/langchain/langchain/chains/query_constructor/ir.py +++ b/libs/langchain/langchain/chains/query_constructor/ir.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Any, List, Optional, Sequence, Union -from langchain.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel class Visitor(ABC): diff --git a/libs/langchain/langchain/chains/query_constructor/parser.py b/libs/langchain/langchain/chains/query_constructor/parser.py index a2f7db6cbd0..26c5360d59f 100644 --- a/libs/langchain/langchain/chains/query_constructor/parser.py +++ b/libs/langchain/langchain/chains/query_constructor/parser.py @@ -2,10 +2,9 @@ import datetime import warnings from typing import Any, Literal, Optional, Sequence, Union +from langchain_core.utils import check_package_version from typing_extensions import TypedDict -from langchain.utils import check_package_version - try: check_package_version("lark", 
gte_version="1.1.5") from lark import Lark, Transformer, v_args diff --git a/libs/langchain/langchain/chains/query_constructor/prompt.py b/libs/langchain/langchain/chains/query_constructor/prompt.py index c3764b2aaed..d1355b32663 100644 --- a/libs/langchain/langchain/chains/query_constructor/prompt.py +++ b/libs/langchain/langchain/chains/query_constructor/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate SONG_DATA_SOURCE = """\ ```json diff --git a/libs/langchain/langchain/chains/query_constructor/schema.py b/libs/langchain/langchain/chains/query_constructor/schema.py index 54998710615..6171b3742f2 100644 --- a/libs/langchain/langchain/chains/query_constructor/schema.py +++ b/libs/langchain/langchain/chains/query_constructor/schema.py @@ -1,4 +1,4 @@ -from langchain.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel class AttributeInfo(BaseModel): diff --git a/libs/langchain/langchain/chains/question_answering/__init__.py b/libs/langchain/langchain/chains/question_answering/__init__.py index 8cb0cad82b5..1625e4e7482 100644 --- a/libs/langchain/langchain/chains/question_answering/__init__.py +++ b/libs/langchain/langchain/chains/question_answering/__init__.py @@ -1,6 +1,9 @@ """Load question answering chains.""" from typing import Any, Mapping, Optional, Protocol +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.prompt_template import BasePromptTemplate + from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.chains import ReduceDocumentsChain @@ -18,8 +21,6 @@ from langchain.chains.question_answering import ( from langchain.chains.question_answering.map_rerank_prompt import ( PROMPT as MAP_RERANK_PROMPT, ) -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.prompt_template import BasePromptTemplate class LoadingCallable(Protocol): diff --git a/libs/langchain/langchain/chains/question_answering/map_reduce_prompt.py b/libs/langchain/langchain/chains/question_answering/map_reduce_prompt.py index 9b6153f9e80..defaa8fea7b 100644 --- a/libs/langchain/langchain/chains/question_answering/map_reduce_prompt.py +++ b/libs/langchain/langchain/chains/question_answering/map_reduce_prompt.py @@ -1,11 +1,11 @@ # flake8: noqa from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model -from langchain.prompts.chat import ( +from langchain_core.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question. Return any relevant text verbatim. 
diff --git a/libs/langchain/langchain/chains/question_answering/map_rerank_prompt.py b/libs/langchain/langchain/chains/question_answering/map_rerank_prompt.py index c8041c6c3af..f9547385f92 100644 --- a/libs/langchain/langchain/chains/question_answering/map_rerank_prompt.py +++ b/libs/langchain/langchain/chains/question_answering/map_rerank_prompt.py @@ -1,6 +1,6 @@ # flake8: noqa from langchain.output_parsers.regex import RegexParser -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate output_parser = RegexParser( regex=r"(.*?)\nScore: (\d*)", diff --git a/libs/langchain/langchain/chains/question_answering/refine_prompts.py b/libs/langchain/langchain/chains/question_answering/refine_prompts.py index d375b948e44..ed4d4417fb7 100644 --- a/libs/langchain/langchain/chains/question_answering/refine_prompts.py +++ b/libs/langchain/langchain/chains/question_answering/refine_prompts.py @@ -1,12 +1,12 @@ # flake8: noqa from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model -from langchain.prompts.chat import ( +from langchain_core.prompts.chat import ( AIMessagePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate DEFAULT_REFINE_PROMPT_TMPL = ( "The original question is as follows: {question}\n" diff --git a/libs/langchain/langchain/chains/question_answering/stuff_prompt.py b/libs/langchain/langchain/chains/question_answering/stuff_prompt.py index 982ccc35d29..ee006433852 100644 --- a/libs/langchain/langchain/chains/question_answering/stuff_prompt.py +++ b/libs/langchain/langchain/chains/question_answering/stuff_prompt.py @@ -1,7 +1,7 @@ # flake8: noqa from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model -from langchain.prompts import PromptTemplate -from langchain.prompts.chat import ( +from langchain_core.prompts import PromptTemplate +from langchain_core.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, diff --git a/libs/langchain/langchain/chains/retrieval_qa/base.py b/libs/langchain/langchain/chains/retrieval_qa/base.py index 4828229d632..2fccea39faa 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/base.py +++ b/libs/langchain/langchain/chains/retrieval_qa/base.py @@ -6,6 +6,12 @@ import warnings from abc import abstractmethod from typing import Any, Dict, List, Optional +from langchain_core.prompts import PromptTemplate +from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.schema import BaseRetriever, Document +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.vectorstore import VectorStore + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -17,11 +23,6 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.question_answering import load_qa_chain from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR -from langchain.prompts import PromptTemplate -from langchain.pydantic_v1 import Extra, Field, root_validator -from langchain.schema import BaseRetriever, Document -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.vectorstore import VectorStore class BaseRetrievalQA(Chain): @@ -198,7 +199,7 @@ class 
RetrievalQA(BaseRetrievalQA): from langchain.llms import OpenAI from langchain.chains import RetrievalQA from langchain.vectorstores import FAISS - from langchain.schema.vectorstore import VectorStoreRetriever + from langchain_core.schema.vectorstore import VectorStoreRetriever retriever = VectorStoreRetriever(vectorstore=FAISS(...)) retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever) diff --git a/libs/langchain/langchain/chains/retrieval_qa/prompt.py b/libs/langchain/langchain/chains/retrieval_qa/prompt.py index 9ebb89eac92..963c353184e 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/prompt.py +++ b/libs/langchain/langchain/chains/retrieval_qa/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. diff --git a/libs/langchain/langchain/chains/router/base.py b/libs/langchain/langchain/chains/router/base.py index 6bc704b8e20..c2acb27f915 100644 --- a/libs/langchain/langchain/chains/router/base.py +++ b/libs/langchain/langchain/chains/router/base.py @@ -4,13 +4,14 @@ from __future__ import annotations from abc import ABC from typing import Any, Dict, List, Mapping, NamedTuple, Optional +from langchain_core.pydantic_v1 import Extra + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, Callbacks, ) from langchain.chains.base import Chain -from langchain.pydantic_v1 import Extra class Route(NamedTuple): diff --git a/libs/langchain/langchain/chains/router/embedding_router.py b/libs/langchain/langchain/chains/router/embedding_router.py index 1f7a716076f..69432177f69 100644 --- a/libs/langchain/langchain/chains/router/embedding_router.py +++ b/libs/langchain/langchain/chains/router/embedding_router.py @@ -2,12 +2,13 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, Sequence, Tuple, Type +from langchain_core.pydantic_v1 import Extra +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.router.base import RouterChain from langchain.docstore.document import Document -from langchain.pydantic_v1 import Extra -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore class EmbeddingRouterChain(RouterChain): diff --git a/libs/langchain/langchain/chains/router/llm_router.py b/libs/langchain/langchain/chains/router/llm_router.py index 74dcf6c0721..a6c8ddc05e3 100644 --- a/libs/langchain/langchain/chains/router/llm_router.py +++ b/libs/langchain/langchain/chains/router/llm_router.py @@ -3,6 +3,14 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, Type, cast +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import ( + BaseOutputParser, + BasePromptTemplate, + OutputParserException, +) +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -10,9 +18,6 @@ from langchain.callbacks.manager import ( from langchain.chains import LLMChain from langchain.chains.router.base import RouterChain from langchain.output_parsers.json import parse_and_check_json_markdown 
-from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException -from langchain.schema.language_model import BaseLanguageModel class LLMRouterChain(RouterChain): diff --git a/libs/langchain/langchain/chains/router/multi_prompt.py b/libs/langchain/langchain/chains/router/multi_prompt.py index f4031b968f2..c28a5d279dc 100644 --- a/libs/langchain/langchain/chains/router/multi_prompt.py +++ b/libs/langchain/langchain/chains/router/multi_prompt.py @@ -3,14 +3,15 @@ from __future__ import annotations from typing import Any, Dict, List, Optional +from langchain_core.prompts import PromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chains import ConversationChain from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.router.base import MultiRouteChain from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE -from langchain.prompts import PromptTemplate -from langchain.schema.language_model import BaseLanguageModel class MultiPromptChain(MultiRouteChain): diff --git a/libs/langchain/langchain/chains/router/multi_retrieval_qa.py b/libs/langchain/langchain/chains/router/multi_retrieval_qa.py index 183a87bb714..01b52e7d521 100644 --- a/libs/langchain/langchain/chains/router/multi_retrieval_qa.py +++ b/libs/langchain/langchain/chains/router/multi_retrieval_qa.py @@ -3,6 +3,10 @@ from __future__ import annotations from typing import Any, Dict, List, Mapping, Optional +from langchain_core.prompts import PromptTemplate +from langchain_core.schema import BaseRetriever +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chains import ConversationChain from langchain.chains.base import Chain from langchain.chains.conversation.prompt import DEFAULT_TEMPLATE @@ -13,9 +17,6 @@ from langchain.chains.router.multi_retrieval_prompt import ( MULTI_RETRIEVAL_ROUTER_TEMPLATE, ) from langchain.chat_models import ChatOpenAI -from langchain.prompts import PromptTemplate -from langchain.schema import BaseRetriever -from langchain.schema.language_model import BaseLanguageModel class MultiRetrievalQAChain(MultiRouteChain): diff --git a/libs/langchain/langchain/chains/sequential.py b/libs/langchain/langchain/chains/sequential.py index 35461013ce4..d9462434fd4 100644 --- a/libs/langchain/langchain/chains/sequential.py +++ b/libs/langchain/langchain/chains/sequential.py @@ -1,13 +1,14 @@ """Chain pipeline where the outputs of one step feed directly into next.""" from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.utils.input import get_color_mapping + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) from langchain.chains.base import Chain -from langchain.pydantic_v1 import Extra, root_validator -from langchain.utils.input import get_color_mapping class SequentialChain(Chain): diff --git a/libs/langchain/langchain/chains/sql_database/prompt.py b/libs/langchain/langchain/chains/sql_database/prompt.py index b212ecd3067..34cd2307b6b 100644 --- a/libs/langchain/langchain/chains/sql_database/prompt.py +++ b/libs/langchain/langchain/chains/sql_database/prompt.py @@ -1,6 +1,6 @@ # flake8: noqa from langchain.output_parsers.list import CommaSeparatedListOutputParser -from 
langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate PROMPT_SUFFIX = """Only use the following tables: diff --git a/libs/langchain/langchain/chains/sql_database/query.py b/libs/langchain/langchain/chains/sql_database/query.py index 99c0fff0a32..b3d19f28575 100644 --- a/libs/langchain/langchain/chains/sql_database/query.py +++ b/libs/langchain/langchain/chains/sql_database/query.py @@ -1,10 +1,11 @@ from typing import List, Optional, TypedDict, Union +from langchain_core.runnables import Runnable, RunnableParallel +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.output_parser import NoOpOutputParser +from langchain_core.schema.prompt_template import BasePromptTemplate + from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.output_parser import NoOpOutputParser -from langchain.schema.prompt_template import BasePromptTemplate -from langchain.schema.runnable import Runnable, RunnableParallel from langchain.utilities.sql_database import SQLDatabase diff --git a/libs/langchain/langchain/chains/summarize/__init__.py b/libs/langchain/langchain/chains/summarize/__init__.py index 681019b107d..4d692d30e0c 100644 --- a/libs/langchain/langchain/chains/summarize/__init__.py +++ b/libs/langchain/langchain/chains/summarize/__init__.py @@ -1,6 +1,9 @@ """Load summarizing chains.""" from typing import Any, Mapping, Optional, Protocol +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain @@ -9,8 +12,6 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel class LoadingCallable(Protocol): diff --git a/libs/langchain/langchain/chains/summarize/map_reduce_prompt.py b/libs/langchain/langchain/chains/summarize/map_reduce_prompt.py index 3cd9f941f43..3cf06395c63 100644 --- a/libs/langchain/langchain/chains/summarize/map_reduce_prompt.py +++ b/libs/langchain/langchain/chains/summarize/map_reduce_prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate prompt_template = """Write a concise summary of the following: diff --git a/libs/langchain/langchain/chains/summarize/refine_prompts.py b/libs/langchain/langchain/chains/summarize/refine_prompts.py index 013d0919f24..63c1c33880c 100644 --- a/libs/langchain/langchain/chains/summarize/refine_prompts.py +++ b/libs/langchain/langchain/chains/summarize/refine_prompts.py @@ -1,4 +1,4 @@ -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate REFINE_PROMPT_TMPL = """\ Your job is to produce a final summary. 
diff --git a/libs/langchain/langchain/chains/summarize/stuff_prompt.py b/libs/langchain/langchain/chains/summarize/stuff_prompt.py index 3cd9f941f43..3cf06395c63 100644 --- a/libs/langchain/langchain/chains/summarize/stuff_prompt.py +++ b/libs/langchain/langchain/chains/summarize/stuff_prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate prompt_template = """Write a concise summary of the following: diff --git a/libs/langchain/langchain/chains/transform.py b/libs/langchain/langchain/chains/transform.py index bdf51bf5c09..e251cff2c93 100644 --- a/libs/langchain/langchain/chains/transform.py +++ b/libs/langchain/langchain/chains/transform.py @@ -3,12 +3,13 @@ import functools import logging from typing import Any, Awaitable, Callable, Dict, List, Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) from langchain.chains.base import Chain -from langchain.pydantic_v1 import Field logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_loaders/base.py b/libs/langchain/langchain/chat_loaders/base.py index 63203588d51..87d3131976b 100644 --- a/libs/langchain/langchain/chat_loaders/base.py +++ b/libs/langchain/langchain/chat_loaders/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Iterator, List -from langchain.schema.chat import ChatSession +from langchain_core.schema.chat import ChatSession class BaseChatLoader(ABC): diff --git a/libs/langchain/langchain/chat_loaders/facebook_messenger.py b/libs/langchain/langchain/chat_loaders/facebook_messenger.py index 644133f1bfc..52fbec5ec47 100644 --- a/libs/langchain/langchain/chat_loaders/facebook_messenger.py +++ b/libs/langchain/langchain/chat_loaders/facebook_messenger.py @@ -3,9 +3,10 @@ import logging from pathlib import Path from typing import Iterator, Union +from langchain_core.schema.chat import ChatSession +from langchain_core.schema.messages import HumanMessage + from langchain.chat_loaders.base import BaseChatLoader -from langchain.schema.chat import ChatSession -from langchain.schema.messages import HumanMessage logger = logging.getLogger(__file__) diff --git a/libs/langchain/langchain/chat_loaders/gmail.py b/libs/langchain/langchain/chat_loaders/gmail.py index f4e57d92412..b22204110e9 100644 --- a/libs/langchain/langchain/chat_loaders/gmail.py +++ b/libs/langchain/langchain/chat_loaders/gmail.py @@ -2,9 +2,10 @@ import base64 import re from typing import Any, Iterator +from langchain_core.schema.chat import ChatSession +from langchain_core.schema.messages import HumanMessage + from langchain.chat_loaders.base import BaseChatLoader -from langchain.schema.chat import ChatSession -from langchain.schema.messages import HumanMessage def _extract_email_content(msg: Any) -> HumanMessage: diff --git a/libs/langchain/langchain/chat_loaders/imessage.py b/libs/langchain/langchain/chat_loaders/imessage.py index 78b3c42974d..319845486ae 100644 --- a/libs/langchain/langchain/chat_loaders/imessage.py +++ b/libs/langchain/langchain/chat_loaders/imessage.py @@ -3,9 +3,10 @@ from __future__ import annotations from pathlib import Path from typing import TYPE_CHECKING, Iterator, List, Optional, Union +from langchain_core.schema import HumanMessage +from langchain_core.schema.chat import ChatSession + from langchain.chat_loaders.base import BaseChatLoader -from langchain.schema import HumanMessage -from 
langchain.schema.chat import ChatSession if TYPE_CHECKING: import sqlite3 diff --git a/libs/langchain/langchain/chat_loaders/langsmith.py b/libs/langchain/langchain/chat_loaders/langsmith.py index 513ea3644c0..dfe2df3521c 100644 --- a/libs/langchain/langchain/chat_loaders/langsmith.py +++ b/libs/langchain/langchain/chat_loaders/langsmith.py @@ -3,9 +3,10 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union, cast +from langchain_core.load import load +from langchain_core.schema.chat import ChatSession + from langchain.chat_loaders.base import BaseChatLoader -from langchain.load import load -from langchain.schema.chat import ChatSession if TYPE_CHECKING: from langsmith.client import Client diff --git a/libs/langchain/langchain/chat_loaders/slack.py b/libs/langchain/langchain/chat_loaders/slack.py index 65791a13f5e..9d5822e590d 100644 --- a/libs/langchain/langchain/chat_loaders/slack.py +++ b/libs/langchain/langchain/chat_loaders/slack.py @@ -5,9 +5,10 @@ import zipfile from pathlib import Path from typing import Dict, Iterator, List, Union +from langchain_core.schema import AIMessage, HumanMessage +from langchain_core.schema.chat import ChatSession + from langchain.chat_loaders.base import BaseChatLoader -from langchain.schema import AIMessage, HumanMessage -from langchain.schema.chat import ChatSession logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_loaders/telegram.py b/libs/langchain/langchain/chat_loaders/telegram.py index d417ebafe15..6661fa0d80f 100644 --- a/libs/langchain/langchain/chat_loaders/telegram.py +++ b/libs/langchain/langchain/chat_loaders/telegram.py @@ -6,9 +6,10 @@ import zipfile from pathlib import Path from typing import Iterator, List, Union +from langchain_core.schema import AIMessage, BaseMessage, HumanMessage +from langchain_core.schema.chat import ChatSession + from langchain.chat_loaders.base import BaseChatLoader -from langchain.schema import AIMessage, BaseMessage, HumanMessage -from langchain.schema.chat import ChatSession logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_loaders/utils.py b/libs/langchain/langchain/chat_loaders/utils.py index b6a83c4fe1b..bd85dc529c0 100644 --- a/libs/langchain/langchain/chat_loaders/utils.py +++ b/libs/langchain/langchain/chat_loaders/utils.py @@ -2,8 +2,8 @@ from copy import deepcopy from typing import Iterable, Iterator, List -from langchain.schema.chat import ChatSession -from langchain.schema.messages import AIMessage, BaseMessage +from langchain_core.schema.chat import ChatSession +from langchain_core.schema.messages import AIMessage, BaseMessage def merge_chat_runs_in_session( diff --git a/libs/langchain/langchain/chat_loaders/whatsapp.py b/libs/langchain/langchain/chat_loaders/whatsapp.py index f8de9c0e411..36638e07c66 100644 --- a/libs/langchain/langchain/chat_loaders/whatsapp.py +++ b/libs/langchain/langchain/chat_loaders/whatsapp.py @@ -4,9 +4,10 @@ import re import zipfile from typing import Iterator, List, Union +from langchain_core.schema import AIMessage, HumanMessage +from langchain_core.schema.chat import ChatSession + from langchain.chat_loaders.base import BaseChatLoader -from langchain.schema import AIMessage, HumanMessage -from langchain.schema.chat import ChatSession logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_models/anthropic.py b/libs/langchain/langchain/chat_models/anthropic.py index ad74089dd3a..b8cc37474d1 100644 --- 
a/libs/langchain/langchain/chat_models/anthropic.py +++ b/libs/langchain/langchain/chat_models/anthropic.py @@ -1,5 +1,16 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast +from langchain_core.schema.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.schema.prompt import PromptValue + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -10,16 +21,6 @@ from langchain.chat_models.base import ( _generate_from_stream, ) from langchain.llms.anthropic import _AnthropicCommon -from langchain.schema.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, -) -from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain.schema.prompt import PromptValue def _convert_one_message_to_text( diff --git a/libs/langchain/langchain/chat_models/anyscale.py b/libs/langchain/langchain/chat_models/anyscale.py index 408c0a6ce10..a94cc48da97 100644 --- a/libs/langchain/langchain/chat_models/anyscale.py +++ b/libs/langchain/langchain/chat_models/anyscale.py @@ -7,15 +7,16 @@ import sys from typing import TYPE_CHECKING, Dict, Optional, Set import requests +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.schema.messages import BaseMessage +from langchain_core.utils import convert_to_secret_str from langchain.adapters.openai import convert_message_to_dict from langchain.chat_models.openai import ( ChatOpenAI, _import_tiktoken, ) -from langchain.pydantic_v1 import Field, SecretStr, root_validator -from langchain.schema.messages import BaseMessage -from langchain.utils import convert_to_secret_str, get_from_dict_or_env +from langchain.utils import get_from_dict_or_env from langchain.utils.openai import is_openai_v1 if TYPE_CHECKING: diff --git a/libs/langchain/langchain/chat_models/azure_openai.py b/libs/langchain/langchain/chat_models/azure_openai.py index 925a475a66e..bf045c42e93 100644 --- a/libs/langchain/langchain/chat_models/azure_openai.py +++ b/libs/langchain/langchain/chat_models/azure_openai.py @@ -6,9 +6,10 @@ import os import warnings from typing import Any, Dict, Union +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.schema import ChatResult + from langchain.chat_models.openai import ChatOpenAI -from langchain.pydantic_v1 import BaseModel, Field, root_validator -from langchain.schema import ChatResult from langchain.utils import get_from_dict_or_env from langchain.utils.openai import is_openai_v1 diff --git a/libs/langchain/langchain/chat_models/azureml_endpoint.py b/libs/langchain/langchain/chat_models/azureml_endpoint.py index 8efa957ad0f..089acc66051 100644 --- a/libs/langchain/langchain/chat_models/azureml_endpoint.py +++ b/libs/langchain/langchain/chat_models/azureml_endpoint.py @@ -1,18 +1,20 @@ import json from typing import Any, Dict, List, Optional, cast -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.chat_models.base import SimpleChatModel -from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase -from langchain.pydantic_v1 import SecretStr, validator -from langchain.schema.messages import ( +from langchain_core.pydantic_v1 import SecretStr, validator +from langchain_core.schema.messages import ( 
AIMessage, BaseMessage, ChatMessage, HumanMessage, SystemMessage, ) -from langchain.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.chat_models.base import SimpleChatModel +from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase +from langchain.utils import get_from_dict_or_env class LlamaContentFormatter(ContentFormatterBase): diff --git a/libs/langchain/langchain/chat_models/baichuan.py b/libs/langchain/langchain/chat_models/baichuan.py index 611d761251f..ee5bb850bbc 100644 --- a/libs/langchain/langchain/chat_models/baichuan.py +++ b/libs/langchain/langchain/chat_models/baichuan.py @@ -5,11 +5,8 @@ import time from typing import Any, Dict, Iterator, List, Mapping, Optional, Type import requests - -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.chat_models.base import BaseChatModel, _generate_from_stream -from langchain.pydantic_v1 import Field, SecretStr, root_validator -from langchain.schema import ( +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.schema import ( AIMessage, BaseMessage, ChatGeneration, @@ -17,19 +14,22 @@ from langchain.schema import ( ChatResult, HumanMessage, ) -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( AIMessageChunk, BaseMessageChunk, ChatMessageChunk, HumanMessageChunk, ) -from langchain.schema.output import ChatGenerationChunk -from langchain.utils import ( +from langchain_core.schema.output import ChatGenerationChunk +from langchain_core.utils import ( convert_to_secret_str, - get_from_dict_or_env, get_pydantic_field_names, ) +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.chat_models.base import BaseChatModel, _generate_from_stream +from langchain.utils import get_from_dict_or_env + logger = logging.getLogger(__name__) DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1" diff --git a/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py b/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py index fa04588a053..27f41463ed3 100644 --- a/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py +++ b/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py @@ -3,14 +3,9 @@ from __future__ import annotations import logging from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, cast -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain.chat_models.base import BaseChatModel -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import ChatGeneration, ChatResult -from langchain.schema.messages import ( +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import ChatGeneration, ChatResult +from langchain_core.schema.messages import ( AIMessage, AIMessageChunk, BaseMessage, @@ -19,7 +14,13 @@ from langchain.schema.messages import ( HumanMessage, SystemMessage, ) -from langchain.schema.output import ChatGenerationChunk +from langchain_core.schema.output import ChatGenerationChunk + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git 
a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index a92d02d6a10..686b88ece9c 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -1,735 +1,13 @@ -import asyncio -import inspect -import warnings -from abc import ABC, abstractmethod -from functools import partial -from typing import ( - Any, - AsyncIterator, - Dict, - Iterator, - List, - Optional, - Sequence, - cast, +from langchain_core.chat_model import ( + BaseChatModel, + SimpleChatModel, + _agenerate_from_stream, + _generate_from_stream, ) -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import ( - AsyncCallbackManager, - AsyncCallbackManagerForLLMRun, - CallbackManager, - CallbackManagerForLLMRun, - Callbacks, -) -from langchain.globals import get_llm_cache -from langchain.load.dump import dumpd, dumps -from langchain.prompts.base import StringPromptValue -from langchain.prompts.chat import ChatPromptValue -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import ( - ChatGeneration, - ChatResult, - LLMResult, - PromptValue, - RunInfo, -) -from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput -from langchain.schema.messages import ( - AIMessage, - AnyMessage, - BaseMessage, - BaseMessageChunk, - HumanMessage, -) -from langchain.schema.output import ChatGenerationChunk -from langchain.schema.runnable import RunnableConfig - - -def _get_verbosity() -> bool: - from langchain.globals import get_verbose - - return get_verbose() - - -def _generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult: - generation: Optional[ChatGenerationChunk] = None - for chunk in stream: - if generation is None: - generation = chunk - else: - generation += chunk - assert generation is not None - return ChatResult(generations=[generation]) - - -async def _agenerate_from_stream( - stream: AsyncIterator[ChatGenerationChunk], -) -> ChatResult: - generation: Optional[ChatGenerationChunk] = None - async for chunk in stream: - if generation is None: - generation = chunk - else: - generation += chunk - assert generation is not None - return ChatResult(generations=[generation]) - - -class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): - """Base class for Chat models.""" - - cache: Optional[bool] = None - """Whether to cache the response.""" - verbose: bool = Field(default_factory=_get_verbosity) - """Whether to print out response text.""" - callbacks: Callbacks = Field(default=None, exclude=True) - """Callbacks to add to the run trace.""" - callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) - """Callback manager to add to the run trace.""" - tags: Optional[List[str]] = Field(default=None, exclude=True) - """Tags to add to the run trace.""" - metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True) - """Metadata to add to the run trace.""" - - @root_validator() - def raise_deprecation(cls, values: Dict) -> Dict: - """Raise deprecation warning if callback_manager is used.""" - if values.get("callback_manager") is not None: - warnings.warn( - "callback_manager is deprecated. 
Please use callbacks instead.", - DeprecationWarning, - ) - values["callbacks"] = values.pop("callback_manager", None) - return values - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - # --- Runnable methods --- - - @property - def OutputType(self) -> Any: - """Get the output type for this runnable.""" - return AnyMessage - - def _convert_input(self, input: LanguageModelInput) -> PromptValue: - if isinstance(input, PromptValue): - return input - elif isinstance(input, str): - return StringPromptValue(text=input) - elif isinstance(input, list): - return ChatPromptValue(messages=input) - else: - raise ValueError( - f"Invalid input type {type(input)}. " - "Must be a PromptValue, str, or list of BaseMessages." - ) - - def invoke( - self, - input: LanguageModelInput, - config: Optional[RunnableConfig] = None, - *, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> BaseMessage: - config = config or {} - return cast( - ChatGeneration, - self.generate_prompt( - [self._convert_input(input)], - stop=stop, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - **kwargs, - ).generations[0][0], - ).message - - async def ainvoke( - self, - input: LanguageModelInput, - config: Optional[RunnableConfig] = None, - *, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> BaseMessage: - config = config or {} - llm_result = await self.agenerate_prompt( - [self._convert_input(input)], - stop=stop, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - **kwargs, - ) - return cast(ChatGeneration, llm_result.generations[0][0]).message - - def stream( - self, - input: LanguageModelInput, - config: Optional[RunnableConfig] = None, - *, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> Iterator[BaseMessageChunk]: - if type(self)._stream == BaseChatModel._stream: - # model doesn't implement streaming, so use default implementation - yield cast( - BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs) - ) - else: - config = config or {} - messages = self._convert_input(input).to_messages() - params = self._get_invocation_params(stop=stop, **kwargs) - options = {"stop": stop, **kwargs} - callback_manager = CallbackManager.configure( - config.get("callbacks"), - self.callbacks, - self.verbose, - config.get("tags"), - self.tags, - config.get("metadata"), - self.metadata, - ) - (run_manager,) = callback_manager.on_chat_model_start( - dumpd(self), - [messages], - invocation_params=params, - options=options, - name=config.get("run_name"), - ) - try: - generation: Optional[ChatGenerationChunk] = None - for chunk in self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ): - yield chunk.message - if generation is None: - generation = chunk - else: - generation += chunk - assert generation is not None - except BaseException as e: - run_manager.on_llm_error(e) - raise e - else: - run_manager.on_llm_end( - LLMResult(generations=[[generation]]), - ) - - async def astream( - self, - input: LanguageModelInput, - config: Optional[RunnableConfig] = None, - *, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> AsyncIterator[BaseMessageChunk]: - if type(self)._astream == BaseChatModel._astream: - # model doesn't implement streaming, so use default implementation - yield cast( - BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs) - ) - 
else: - config = config or {} - messages = self._convert_input(input).to_messages() - params = self._get_invocation_params(stop=stop, **kwargs) - options = {"stop": stop, **kwargs} - callback_manager = AsyncCallbackManager.configure( - config.get("callbacks"), - self.callbacks, - self.verbose, - config.get("tags"), - self.tags, - config.get("metadata"), - self.metadata, - ) - (run_manager,) = await callback_manager.on_chat_model_start( - dumpd(self), - [messages], - invocation_params=params, - options=options, - name=config.get("run_name"), - ) - try: - generation: Optional[ChatGenerationChunk] = None - async for chunk in self._astream( - messages, stop=stop, run_manager=run_manager, **kwargs - ): - yield chunk.message - if generation is None: - generation = chunk - else: - generation += chunk - assert generation is not None - except BaseException as e: - await run_manager.on_llm_error(e) - raise e - else: - await run_manager.on_llm_end( - LLMResult(generations=[[generation]]), - ) - - # --- Custom methods --- - - def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: - return {} - - def _get_invocation_params( - self, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> dict: - params = self.dict() - params["stop"] = stop - return {**params, **kwargs} - - def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str: - if self.is_lc_serializable(): - params = {**kwargs, **{"stop": stop}} - param_string = str(sorted([(k, v) for k, v in params.items()])) - llm_string = dumps(self) - return llm_string + "---" + param_string - else: - params = self._get_invocation_params(stop=stop, **kwargs) - params = {**params, **kwargs} - return str(sorted([(k, v) for k, v in params.items()])) - - def generate( - self, - messages: List[List[BaseMessage]], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - run_name: Optional[str] = None, - **kwargs: Any, - ) -> LLMResult: - """Top Level call""" - params = self._get_invocation_params(stop=stop, **kwargs) - options = {"stop": stop} - - callback_manager = CallbackManager.configure( - callbacks, - self.callbacks, - self.verbose, - tags, - self.tags, - metadata, - self.metadata, - ) - run_managers = callback_manager.on_chat_model_start( - dumpd(self), - messages, - invocation_params=params, - options=options, - name=run_name, - ) - results = [] - for i, m in enumerate(messages): - try: - results.append( - self._generate_with_cache( - m, - stop=stop, - run_manager=run_managers[i] if run_managers else None, - **kwargs, - ) - ) - except BaseException as e: - if run_managers: - run_managers[i].on_llm_error(e) - raise e - flattened_outputs = [ - LLMResult(generations=[res.generations], llm_output=res.llm_output) - for res in results - ] - llm_output = self._combine_llm_outputs([res.llm_output for res in results]) - generations = [res.generations for res in results] - output = LLMResult(generations=generations, llm_output=llm_output) - if run_managers: - run_infos = [] - for manager, flattened_output in zip(run_managers, flattened_outputs): - manager.on_llm_end(flattened_output) - run_infos.append(RunInfo(run_id=manager.run_id)) - output.run = run_infos - return output - - async def agenerate( - self, - messages: List[List[BaseMessage]], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - run_name: Optional[str] = 
None, - **kwargs: Any, - ) -> LLMResult: - """Top Level call""" - params = self._get_invocation_params(stop=stop, **kwargs) - options = {"stop": stop} - - callback_manager = AsyncCallbackManager.configure( - callbacks, - self.callbacks, - self.verbose, - tags, - self.tags, - metadata, - self.metadata, - ) - - run_managers = await callback_manager.on_chat_model_start( - dumpd(self), - messages, - invocation_params=params, - options=options, - name=run_name, - ) - - results = await asyncio.gather( - *[ - self._agenerate_with_cache( - m, - stop=stop, - run_manager=run_managers[i] if run_managers else None, - **kwargs, - ) - for i, m in enumerate(messages) - ], - return_exceptions=True, - ) - exceptions = [] - for i, res in enumerate(results): - if isinstance(res, BaseException): - if run_managers: - await run_managers[i].on_llm_error(res) - exceptions.append(res) - if exceptions: - if run_managers: - await asyncio.gather( - *[ - run_manager.on_llm_end( - LLMResult( - generations=[res.generations], llm_output=res.llm_output - ) - ) - for run_manager, res in zip(run_managers, results) - if not isinstance(res, Exception) - ] - ) - raise exceptions[0] - flattened_outputs = [ - LLMResult(generations=[res.generations], llm_output=res.llm_output) - for res in results - ] - llm_output = self._combine_llm_outputs([res.llm_output for res in results]) - generations = [res.generations for res in results] - output = LLMResult(generations=generations, llm_output=llm_output) - await asyncio.gather( - *[ - run_manager.on_llm_end(flattened_output) - for run_manager, flattened_output in zip( - run_managers, flattened_outputs - ) - ] - ) - if run_managers: - output.run = [ - RunInfo(run_id=run_manager.run_id) for run_manager in run_managers - ] - return output - - def generate_prompt( - self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - prompt_messages = [p.to_messages() for p in prompts] - return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs) - - async def agenerate_prompt( - self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - prompt_messages = [p.to_messages() for p in prompts] - return await self.agenerate( - prompt_messages, stop=stop, callbacks=callbacks, **kwargs - ) - - def _generate_with_cache( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - new_arg_supported = inspect.signature(self._generate).parameters.get( - "run_manager" - ) - disregard_cache = self.cache is not None and not self.cache - llm_cache = get_llm_cache() - if llm_cache is None or disregard_cache: - # This happens when langchain.cache is None, but self.cache is True - if self.cache is not None and self.cache: - raise ValueError( - "Asked to cache, but no cache found at `langchain.cache`." 
- ) - if new_arg_supported: - return self._generate( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - else: - return self._generate(messages, stop=stop, **kwargs) - else: - llm_string = self._get_llm_string(stop=stop, **kwargs) - prompt = dumps(messages) - cache_val = llm_cache.lookup(prompt, llm_string) - if isinstance(cache_val, list): - return ChatResult(generations=cache_val) - else: - if new_arg_supported: - result = self._generate( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - else: - result = self._generate(messages, stop=stop, **kwargs) - llm_cache.update(prompt, llm_string, result.generations) - return result - - async def _agenerate_with_cache( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - new_arg_supported = inspect.signature(self._agenerate).parameters.get( - "run_manager" - ) - disregard_cache = self.cache is not None and not self.cache - llm_cache = get_llm_cache() - if llm_cache is None or disregard_cache: - # This happens when langchain.cache is None, but self.cache is True - if self.cache is not None and self.cache: - raise ValueError( - "Asked to cache, but no cache found at `langchain.cache`." - ) - if new_arg_supported: - return await self._agenerate( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - else: - return await self._agenerate(messages, stop=stop, **kwargs) - else: - llm_string = self._get_llm_string(stop=stop, **kwargs) - prompt = dumps(messages) - cache_val = llm_cache.lookup(prompt, llm_string) - if isinstance(cache_val, list): - return ChatResult(generations=cache_val) - else: - if new_arg_supported: - result = await self._agenerate( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - else: - result = await self._agenerate(messages, stop=stop, **kwargs) - llm_cache.update(prompt, llm_string, result.generations) - return result - - @abstractmethod - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """Top Level call""" - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """Top Level call""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(self._generate, **kwargs), messages, stop, run_manager - ) - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - raise NotImplementedError() - - def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - raise NotImplementedError() - - def __call__( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> BaseMessage: - generation = self.generate( - [messages], stop=stop, callbacks=callbacks, **kwargs - ).generations[0][0] - if isinstance(generation, ChatGeneration): - return generation.message - else: - raise ValueError("Unexpected generation type") - - async def _call_async( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - 
callbacks: Callbacks = None, - **kwargs: Any, - ) -> BaseMessage: - result = await self.agenerate( - [messages], stop=stop, callbacks=callbacks, **kwargs - ) - generation = result.generations[0][0] - if isinstance(generation, ChatGeneration): - return generation.message - else: - raise ValueError("Unexpected generation type") - - def call_as_llm( - self, message: str, stop: Optional[List[str]] = None, **kwargs: Any - ) -> str: - return self.predict(message, stop=stop, **kwargs) - - def predict( - self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any - ) -> str: - if stop is None: - _stop = None - else: - _stop = list(stop) - result = self([HumanMessage(content=text)], stop=_stop, **kwargs) - if isinstance(result.content, str): - return result.content - else: - raise ValueError("Cannot use predict when output is not a string.") - - def predict_messages( - self, - messages: List[BaseMessage], - *, - stop: Optional[Sequence[str]] = None, - **kwargs: Any, - ) -> BaseMessage: - if stop is None: - _stop = None - else: - _stop = list(stop) - return self(messages, stop=_stop, **kwargs) - - async def apredict( - self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any - ) -> str: - if stop is None: - _stop = None - else: - _stop = list(stop) - result = await self._call_async( - [HumanMessage(content=text)], stop=_stop, **kwargs - ) - if isinstance(result.content, str): - return result.content - else: - raise ValueError("Cannot use predict when output is not a string.") - - async def apredict_messages( - self, - messages: List[BaseMessage], - *, - stop: Optional[Sequence[str]] = None, - **kwargs: Any, - ) -> BaseMessage: - if stop is None: - _stop = None - else: - _stop = list(stop) - return await self._call_async(messages, stop=_stop, **kwargs) - - @property - def _identifying_params(self) -> Dict[str, Any]: - """Get the identifying parameters.""" - return {} - - @property - @abstractmethod - def _llm_type(self) -> str: - """Return type of chat model.""" - - def dict(self, **kwargs: Any) -> Dict: - """Return a dictionary of the LLM.""" - starter_dict = dict(self._identifying_params) - starter_dict["_type"] = self._llm_type - return starter_dict - - -class SimpleChatModel(BaseChatModel): - """Simple Chat Model.""" - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs) - message = AIMessage(content=output_str) - generation = ChatGeneration(message=message) - return ChatResult(generations=[generation]) - - @abstractmethod - def _call( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - """Simpler interface.""" - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - func = partial( - self._generate, messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await asyncio.get_event_loop().run_in_executor(None, func) +__all__ = [ + "BaseChatModel", + "SimpleChatModel", + "_generate_from_stream", + "_agenerate_from_stream", +] diff --git a/libs/langchain/langchain/chat_models/bedrock.py b/libs/langchain/langchain/chat_models/bedrock.py index 34bb97f0f79..36ed3fe6eb0 100644 --- 
a/libs/langchain/langchain/chat_models/bedrock.py +++ b/libs/langchain/langchain/chat_models/bedrock.py @@ -1,5 +1,9 @@ from typing import Any, Dict, Iterator, List, Optional +from langchain_core.pydantic_v1 import Extra +from langchain_core.schema.messages import AIMessage, AIMessageChunk, BaseMessage +from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult + from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) @@ -7,9 +11,6 @@ from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic from langchain.chat_models.base import BaseChatModel from langchain.chat_models.meta import convert_messages_to_prompt_llama from langchain.llms.bedrock import BedrockBase -from langchain.pydantic_v1 import Extra -from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage -from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult from langchain.utilities.anthropic import ( get_num_tokens_anthropic, get_token_ids_anthropic, diff --git a/libs/langchain/langchain/chat_models/cohere.py b/libs/langchain/langchain/chat_models/cohere.py index 632997c98bc..b5bce8ddab1 100644 --- a/libs/langchain/langchain/chat_models/cohere.py +++ b/libs/langchain/langchain/chat_models/cohere.py @@ -1,5 +1,15 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +from langchain_core.schema.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -10,15 +20,6 @@ from langchain.chat_models.base import ( _generate_from_stream, ) from langchain.llms.cohere import BaseCohere -from langchain.schema.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, -) -from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult def get_role(message: BaseMessage) -> str: @@ -102,7 +103,7 @@ class ChatCohere(BaseChatModel, BaseCohere): .. 
code-block:: python from langchain.chat_models import ChatCohere - from langchain.schema import HumanMessage + from langchain_core.schema import HumanMessage chat = ChatCohere(model="foo") result = chat([HumanMessage(content="Hello")]) diff --git a/libs/langchain/langchain/chat_models/ernie.py b/libs/langchain/langchain/chat_models/ernie.py index 58e7647d1bd..65a037218a5 100644 --- a/libs/langchain/langchain/chat_models/ernie.py +++ b/libs/langchain/langchain/chat_models/ernie.py @@ -4,11 +4,8 @@ import threading from typing import Any, Dict, List, Mapping, Optional import requests - -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.chat_models.base import BaseChatModel -from langchain.pydantic_v1 import root_validator -from langchain.schema import ( +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import ( AIMessage, BaseMessage, ChatGeneration, @@ -16,6 +13,9 @@ from langchain.schema import ( ChatResult, HumanMessage, ) + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.chat_models.base import BaseChatModel from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_models/everlyai.py b/libs/langchain/langchain/chat_models/everlyai.py index 0b5dbd85d06..5b846fdce9c 100644 --- a/libs/langchain/langchain/chat_models/everlyai.py +++ b/libs/langchain/langchain/chat_models/everlyai.py @@ -5,13 +5,14 @@ import logging import sys from typing import TYPE_CHECKING, Dict, Optional, Set +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema.messages import BaseMessage + from langchain.adapters.openai import convert_message_to_dict from langchain.chat_models.openai import ( ChatOpenAI, _import_tiktoken, ) -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema.messages import BaseMessage from langchain.utils import get_from_dict_or_env if TYPE_CHECKING: diff --git a/libs/langchain/langchain/chat_models/fake.py b/libs/langchain/langchain/chat_models/fake.py index 3c1e50a9b52..7d7c5330bef 100644 --- a/libs/langchain/langchain/chat_models/fake.py +++ b/libs/langchain/langchain/chat_models/fake.py @@ -3,14 +3,15 @@ import asyncio import time from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union +from langchain_core.schema import ChatResult +from langchain_core.schema.messages import AIMessageChunk, BaseMessage +from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.chat_models.base import BaseChatModel, SimpleChatModel -from langchain.schema import ChatResult -from langchain.schema.messages import AIMessageChunk, BaseMessage -from langchain.schema.output import ChatGeneration, ChatGenerationChunk class FakeMessagesListChatModel(BaseChatModel): diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py index 36a7d582369..32ff6d20b91 100644 --- a/libs/langchain/langchain/chat_models/fireworks.py +++ b/libs/langchain/langchain/chat_models/fireworks.py @@ -10,15 +10,8 @@ from typing import ( Union, ) -from langchain.adapters.openai import convert_message_to_dict -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain.chat_models.base import BaseChatModel -from langchain.llms.base import 
create_base_retry_decorator -from langchain.pydantic_v1 import Field, SecretStr, root_validator -from langchain.schema.messages import ( +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.schema.messages import ( AIMessage, AIMessageChunk, BaseMessage, @@ -32,8 +25,16 @@ from langchain.schema.messages import ( SystemMessage, SystemMessageChunk, ) -from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain.utils import convert_to_secret_str +from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.utils import convert_to_secret_str + +from langchain.adapters.openai import convert_message_to_dict +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.llms.base import create_base_retry_decorator from langchain.utils.env import get_from_dict_or_env diff --git a/libs/langchain/langchain/chat_models/gigachat.py b/libs/langchain/langchain/chat_models/gigachat.py index 2b8ac7f2133..601309f99a7 100644 --- a/libs/langchain/langchain/chat_models/gigachat.py +++ b/libs/langchain/langchain/chat_models/gigachat.py @@ -1,6 +1,17 @@ import logging from typing import Any, AsyncIterator, Iterator, List, Optional +from langchain_core.schema import ChatResult +from langchain_core.schema.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -11,16 +22,6 @@ from langchain.chat_models.base import ( _generate_from_stream, ) from langchain.llms.gigachat import _BaseGigaChat -from langchain.schema import ChatResult -from langchain.schema.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, -) -from langchain.schema.output import ChatGeneration, ChatGenerationChunk logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_models/google_palm.py b/libs/langchain/langchain/chat_models/google_palm.py index 3f68d4e45b8..ed85c236f68 100644 --- a/libs/langchain/langchain/chat_models/google_palm.py +++ b/libs/langchain/langchain/chat_models/google_palm.py @@ -4,6 +4,18 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema import ( + ChatGeneration, + ChatResult, +) +from langchain_core.schema.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) from tenacity import ( before_sleep_log, retry, @@ -17,18 +29,6 @@ from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) from langchain.chat_models.base import BaseChatModel -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema import ( - ChatGeneration, - ChatResult, -) -from langchain.schema.messages import ( - AIMessage, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, -) from langchain.utils import get_from_dict_or_env if TYPE_CHECKING: diff --git a/libs/langchain/langchain/chat_models/human.py b/libs/langchain/langchain/chat_models/human.py index f085cb2515b..98e594a979c 100644 --- a/libs/langchain/langchain/chat_models/human.py +++ 
b/libs/langchain/langchain/chat_models/human.py @@ -5,6 +5,14 @@ from io import StringIO from typing import Any, Callable, Dict, List, Mapping, Optional import yaml +from langchain_core.pydantic_v1 import Field +from langchain_core.schema.messages import ( + BaseMessage, + HumanMessage, + _message_from_dict, + messages_to_dict, +) +from langchain_core.schema.output import ChatGeneration, ChatResult from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -12,14 +20,6 @@ from langchain.callbacks.manager import ( ) from langchain.chat_models.base import BaseChatModel from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Field -from langchain.schema.messages import ( - BaseMessage, - HumanMessage, - _message_from_dict, - messages_to_dict, -) -from langchain.schema.output import ChatGeneration, ChatResult def _display_messages(messages: List[BaseMessage]) -> None: diff --git a/libs/langchain/langchain/chat_models/hunyuan.py b/libs/langchain/langchain/chat_models/hunyuan.py index ffa01ef31a4..040b1007c03 100644 --- a/libs/langchain/langchain/chat_models/hunyuan.py +++ b/libs/langchain/langchain/chat_models/hunyuan.py @@ -8,11 +8,8 @@ from typing import Any, Dict, Iterator, List, Mapping, Optional, Type from urllib.parse import urlparse import requests - -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.chat_models.base import BaseChatModel, _generate_from_stream -from langchain.pydantic_v1 import Field, SecretStr, root_validator -from langchain.schema import ( +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.schema import ( AIMessage, BaseMessage, ChatGeneration, @@ -20,19 +17,22 @@ from langchain.schema import ( ChatResult, HumanMessage, ) -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( AIMessageChunk, BaseMessageChunk, ChatMessageChunk, HumanMessageChunk, ) -from langchain.schema.output import ChatGenerationChunk -from langchain.utils import ( +from langchain_core.schema.output import ChatGenerationChunk +from langchain_core.utils import ( convert_to_secret_str, - get_from_dict_or_env, get_pydantic_field_names, ) +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.chat_models.base import BaseChatModel, _generate_from_stream +from langchain.utils import get_from_dict_or_env + logger = logging.getLogger(__name__) DEFAULT_API_BASE = "https://hunyuan.cloud.tencent.com" diff --git a/libs/langchain/langchain/chat_models/javelin_ai_gateway.py b/libs/langchain/langchain/chat_models/javelin_ai_gateway.py index 35c2c181777..48a36a2f5a2 100644 --- a/libs/langchain/langchain/chat_models/javelin_ai_gateway.py +++ b/libs/langchain/langchain/chat_models/javelin_ai_gateway.py @@ -1,17 +1,12 @@ import logging from typing import Any, Dict, List, Mapping, Optional, cast -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain.chat_models.base import BaseChatModel -from langchain.pydantic_v1 import BaseModel, Extra, SecretStr -from langchain.schema import ( +from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr +from langchain_core.schema import ( ChatGeneration, ChatResult, ) -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( AIMessage, BaseMessage, ChatMessage, @@ -20,6 +15,12 @@ from langchain.schema.messages import ( SystemMessage, ) +from langchain.callbacks.manager import ( + 
AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel + logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_models/jinachat.py b/libs/langchain/langchain/chat_models/jinachat.py index 7b885d6482c..ee4347e5ef3 100644 --- a/libs/langchain/langchain/chat_models/jinachat.py +++ b/libs/langchain/langchain/chat_models/jinachat.py @@ -16,6 +16,26 @@ from typing import ( Union, ) +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import ( + AIMessage, + BaseMessage, + ChatGeneration, + ChatMessage, + ChatResult, + FunctionMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.schema.messages import ( + AIMessageChunk, + BaseMessageChunk, + ChatMessageChunk, + HumanMessageChunk, + SystemMessageChunk, +) +from langchain_core.schema.output import ChatGenerationChunk +from langchain_core.utils import get_pydantic_field_names from tenacity import ( before_sleep_log, retry, @@ -33,26 +53,7 @@ from langchain.chat_models.base import ( _agenerate_from_stream, _generate_from_stream, ) -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import ( - AIMessage, - BaseMessage, - ChatGeneration, - ChatMessage, - ChatResult, - FunctionMessage, - HumanMessage, - SystemMessage, -) -from langchain.schema.messages import ( - AIMessageChunk, - BaseMessageChunk, - ChatMessageChunk, - HumanMessageChunk, - SystemMessageChunk, -) -from langchain.schema.output import ChatGenerationChunk -from langchain.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_models/konko.py b/libs/langchain/langchain/chat_models/konko.py index 6c5c5ef2db5..aad978123df 100644 --- a/libs/langchain/langchain/chat_models/konko.py +++ b/libs/langchain/langchain/chat_models/konko.py @@ -16,6 +16,10 @@ from typing import ( ) import requests +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import ChatGeneration, ChatResult +from langchain_core.schema.messages import AIMessageChunk, BaseMessage +from langchain_core.schema.output import ChatGenerationChunk from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict from langchain.callbacks.manager import ( @@ -23,10 +27,6 @@ from langchain.callbacks.manager import ( ) from langchain.chat_models.base import BaseChatModel, _generate_from_stream from langchain.chat_models.openai import _convert_delta_to_message_chunk -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import ChatGeneration, ChatResult -from langchain.schema.messages import AIMessageChunk, BaseMessage -from langchain.schema.output import ChatGenerationChunk from langchain.utils import get_from_dict_or_env DEFAULT_API_BASE = "https://api.konko.ai/v1" diff --git a/libs/langchain/langchain/chat_models/litellm.py b/libs/langchain/langchain/chat_models/litellm.py index ac1c1c9c7c8..c5b32fae61d 100644 --- a/libs/langchain/langchain/chat_models/litellm.py +++ b/libs/langchain/langchain/chat_models/litellm.py @@ -16,22 +16,12 @@ from typing import ( Union, ) -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain.chat_models.base import ( - BaseChatModel, - _agenerate_from_stream, - _generate_from_stream, -) -from langchain.llms.base import create_base_retry_decorator -from 
langchain.pydantic_v1 import Field, root_validator -from langchain.schema import ( +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import ( ChatGeneration, ChatResult, ) -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( AIMessage, AIMessageChunk, BaseMessage, @@ -45,7 +35,18 @@ from langchain.schema.messages import ( SystemMessage, SystemMessageChunk, ) -from langchain.schema.output import ChatGenerationChunk +from langchain_core.schema.output import ChatGenerationChunk + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import ( + BaseChatModel, + _agenerate_from_stream, + _generate_from_stream, +) +from langchain.llms.base import create_base_retry_decorator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_models/minimax.py b/libs/langchain/langchain/chat_models/minimax.py index 0fad5e24b10..4e23c17bb2c 100644 --- a/libs/langchain/langchain/chat_models/minimax.py +++ b/libs/langchain/langchain/chat_models/minimax.py @@ -2,6 +2,13 @@ import logging from typing import Any, Dict, List, Optional, cast +from langchain_core.schema import ( + AIMessage, + BaseMessage, + ChatResult, + HumanMessage, +) + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -9,12 +16,6 @@ from langchain.callbacks.manager import ( from langchain.chat_models.base import BaseChatModel from langchain.llms.minimax import MinimaxCommon from langchain.llms.utils import enforce_stop_tokens -from langchain.schema import ( - AIMessage, - BaseMessage, - ChatResult, - HumanMessage, -) logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_models/mlflow_ai_gateway.py b/libs/langchain/langchain/chat_models/mlflow_ai_gateway.py index a1c02abb7d8..74ec0def45e 100644 --- a/libs/langchain/langchain/chat_models/mlflow_ai_gateway.py +++ b/libs/langchain/langchain/chat_models/mlflow_ai_gateway.py @@ -3,17 +3,12 @@ import logging from functools import partial from typing import Any, Dict, List, Mapping, Optional -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain.chat_models.base import BaseChatModel -from langchain.pydantic_v1 import BaseModel, Extra -from langchain.schema import ( +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.schema import ( ChatGeneration, ChatResult, ) -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( AIMessage, BaseMessage, ChatMessage, @@ -22,6 +17,12 @@ from langchain.schema.messages import ( SystemMessage, ) +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel + logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_models/ollama.py b/libs/langchain/langchain/chat_models/ollama.py index b54d0063c77..27efd7db2c0 100644 --- a/libs/langchain/langchain/chat_models/ollama.py +++ b/libs/langchain/langchain/chat_models/ollama.py @@ -1,13 +1,8 @@ import json from typing import Any, Iterator, List, Optional -from langchain.callbacks.manager import ( - CallbackManagerForLLMRun, -) -from langchain.chat_models.base import BaseChatModel -from langchain.llms.ollama import _OllamaCommon -from langchain.schema import ChatResult -from 
langchain.schema.messages import ( +from langchain_core.schema import ChatResult +from langchain_core.schema.messages import ( AIMessage, AIMessageChunk, BaseMessage, @@ -15,7 +10,13 @@ from langchain.schema.messages import ( HumanMessage, SystemMessage, ) -from langchain.schema.output import ChatGeneration, ChatGenerationChunk +from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk + +from langchain.callbacks.manager import ( + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.llms.ollama import _OllamaCommon def _stream_response_to_chat_generation_chunk( diff --git a/libs/langchain/langchain/chat_models/openai.py b/libs/langchain/langchain/chat_models/openai.py index 9f25122de54..e3de65d7ebe 100644 --- a/libs/langchain/langchain/chat_models/openai.py +++ b/libs/langchain/langchain/chat_models/openai.py @@ -20,6 +20,25 @@ from typing import ( Union, ) +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.runnables import Runnable +from langchain_core.schema import ChatGeneration, ChatResult +from langchain_core.schema.language_model import LanguageModelInput +from langchain_core.schema.messages import ( + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessageChunk, + FunctionMessageChunk, + HumanMessageChunk, + SystemMessageChunk, + ToolMessageChunk, +) +from langchain_core.schema.output import ChatGenerationChunk +from langchain_core.utils import ( + get_pydantic_field_names, +) + from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -31,25 +50,7 @@ from langchain.chat_models.base import ( _generate_from_stream, ) from langchain.llms.base import create_base_retry_decorator -from langchain.pydantic_v1 import BaseModel, Field, root_validator -from langchain.schema import ChatGeneration, ChatResult -from langchain.schema.language_model import LanguageModelInput -from langchain.schema.messages import ( - AIMessageChunk, - BaseMessage, - BaseMessageChunk, - ChatMessageChunk, - FunctionMessageChunk, - HumanMessageChunk, - SystemMessageChunk, - ToolMessageChunk, -) -from langchain.schema.output import ChatGenerationChunk -from langchain.schema.runnable import Runnable -from langchain.utils import ( - get_from_dict_or_env, - get_pydantic_field_names, -) +from langchain.utils import get_from_dict_or_env from langchain.utils.openai import is_openai_v1 if TYPE_CHECKING: diff --git a/libs/langchain/langchain/chat_models/pai_eas_endpoint.py b/libs/langchain/langchain/chat_models/pai_eas_endpoint.py index cc029cffd2c..23329144b72 100644 --- a/libs/langchain/langchain/chat_models/pai_eas_endpoint.py +++ b/libs/langchain/langchain/chat_models/pai_eas_endpoint.py @@ -5,16 +5,9 @@ from functools import partial from typing import Any, AsyncIterator, Dict, List, Optional, cast import requests - -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain.chat_models.base import BaseChatModel -from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import root_validator -from langchain.schema import ChatGeneration, ChatResult -from langchain.schema.messages import ( +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import ChatGeneration, ChatResult +from langchain_core.schema.messages import ( AIMessage, AIMessageChunk, BaseMessage, @@ -22,7 +15,14 @@ from 
langchain.schema.messages import ( HumanMessage, SystemMessage, ) -from langchain.schema.output import ChatGenerationChunk +from langchain_core.schema.output import ChatGenerationChunk + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_models/promptlayer_openai.py b/libs/langchain/langchain/chat_models/promptlayer_openai.py index c55e6051c71..9db78029ba7 100644 --- a/libs/langchain/langchain/chat_models/promptlayer_openai.py +++ b/libs/langchain/langchain/chat_models/promptlayer_openai.py @@ -2,13 +2,14 @@ import datetime from typing import Any, Dict, List, Optional +from langchain_core.schema import ChatResult +from langchain_core.schema.messages import BaseMessage + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.chat_models import ChatOpenAI -from langchain.schema import ChatResult -from langchain.schema.messages import BaseMessage class PromptLayerChatOpenAI(ChatOpenAI): diff --git a/libs/langchain/langchain/chat_models/tongyi.py b/libs/langchain/langchain/chat_models/tongyi.py index 9176a1ae20b..c079349840c 100644 --- a/libs/langchain/langchain/chat_models/tongyi.py +++ b/libs/langchain/langchain/chat_models/tongyi.py @@ -13,6 +13,23 @@ from typing import ( Type, ) +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import ChatGeneration, ChatResult +from langchain_core.schema.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + ChatMessageChunk, + FunctionMessage, + FunctionMessageChunk, + HumanMessage, + HumanMessageChunk, + SystemMessage, + SystemMessageChunk, +) +from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk from requests.exceptions import HTTPError from tenacity import ( RetryCallState, @@ -27,23 +44,6 @@ from langchain.chat_models.base import ( BaseChatModel, _generate_from_stream, ) -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import ChatGeneration, ChatResult -from langchain.schema.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - BaseMessageChunk, - ChatMessage, - ChatMessageChunk, - FunctionMessage, - FunctionMessageChunk, - HumanMessage, - HumanMessageChunk, - SystemMessage, - SystemMessageChunk, -) -from langchain.schema.output import ChatGenerationChunk, GenerationChunk from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_models/vertexai.py b/libs/langchain/langchain/chat_models/vertexai.py index 32495be5b81..54077319dd4 100644 --- a/libs/langchain/langchain/chat_models/vertexai.py +++ b/libs/langchain/langchain/chat_models/vertexai.py @@ -5,22 +5,23 @@ import logging from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain.chat_models.base import BaseChatModel, _generate_from_stream -from langchain.llms.vertexai import _VertexAICommon, is_codey_model -from langchain.pydantic_v1 import root_validator -from langchain.schema import ChatGeneration, ChatResult -from langchain.schema.messages 
import ( +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import ChatGeneration, ChatResult +from langchain_core.schema.messages import ( AIMessage, AIMessageChunk, BaseMessage, HumanMessage, SystemMessage, ) -from langchain.schema.output import ChatGenerationChunk +from langchain_core.schema.output import ChatGenerationChunk + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel, _generate_from_stream +from langchain.llms.vertexai import _VertexAICommon, is_codey_model from langchain.utilities.vertexai import raise_vertex_import_error if TYPE_CHECKING: diff --git a/libs/langchain/langchain/chat_models/yandex.py b/libs/langchain/langchain/chat_models/yandex.py index 0847028aee4..c789ffa0819 100644 --- a/libs/langchain/langchain/chat_models/yandex.py +++ b/libs/langchain/langchain/chat_models/yandex.py @@ -2,6 +2,15 @@ import logging from typing import Any, Dict, List, Optional, Tuple, cast +from langchain_core.schema import ( + AIMessage, + BaseMessage, + ChatGeneration, + ChatResult, + HumanMessage, + SystemMessage, +) + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -9,14 +18,6 @@ from langchain.callbacks.manager import ( from langchain.chat_models.base import BaseChatModel from langchain.llms.utils import enforce_stop_tokens from langchain.llms.yandex import _BaseYandexGPT -from langchain.schema import ( - AIMessage, - BaseMessage, - ChatGeneration, - ChatResult, - HumanMessage, - SystemMessage, -) logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/docstore/arbitrary_fn.py b/libs/langchain/langchain/docstore/arbitrary_fn.py index 346e9ff6d6f..179a2773c8a 100644 --- a/libs/langchain/langchain/docstore/arbitrary_fn.py +++ b/libs/langchain/langchain/docstore/arbitrary_fn.py @@ -1,7 +1,8 @@ from typing import Callable, Union +from langchain_core.schema import Document + from langchain.docstore.base import Docstore -from langchain.schema import Document class DocstoreFn(Docstore): diff --git a/libs/langchain/langchain/docstore/document.py b/libs/langchain/langchain/docstore/document.py index 1c33318db28..a2825e674d8 100644 --- a/libs/langchain/langchain/docstore/document.py +++ b/libs/langchain/langchain/docstore/document.py @@ -1,3 +1,3 @@ -from langchain.schema import Document +from langchain_core.schema import Document __all__ = ["Document"] diff --git a/libs/langchain/langchain/document_loaders/airbyte.py b/libs/langchain/langchain/document_loaders/airbyte.py index 7f9d4b4407a..0369ad67d28 100644 --- a/libs/langchain/langchain/document_loaders/airbyte.py +++ b/libs/langchain/langchain/document_loaders/airbyte.py @@ -1,8 +1,9 @@ from typing import Any, Callable, Iterator, List, Mapping, Optional +from langchain_core.utils.utils import guard_import + from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader -from langchain.utils.utils import guard_import RecordHandler = Callable[[Any, Optional[str]], Document] diff --git a/libs/langchain/langchain/document_loaders/apify_dataset.py b/libs/langchain/langchain/document_loaders/apify_dataset.py index 25658037d6b..2e273966310 100644 --- a/libs/langchain/langchain/document_loaders/apify_dataset.py +++ b/libs/langchain/langchain/document_loaders/apify_dataset.py @@ -1,8 +1,9 @@ from typing import Any, Callable, Dict, List +from langchain_core.pydantic_v1 import BaseModel, 
root_validator + from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader -from langchain.pydantic_v1 import BaseModel, root_validator class ApifyDatasetLoader(BaseLoader, BaseModel): @@ -14,7 +15,7 @@ class ApifyDatasetLoader(BaseLoader, BaseModel): .. code-block:: python from langchain.document_loaders import ApifyDatasetLoader - from langchain.schema import Document + from langchain_core.schema import Document loader = ApifyDatasetLoader( dataset_id="YOUR-DATASET-ID", diff --git a/libs/langchain/langchain/document_loaders/base.py b/libs/langchain/langchain/document_loaders/base.py index 486173ec332..ae036e2c67c 100644 --- a/libs/langchain/langchain/document_loaders/base.py +++ b/libs/langchain/langchain/document_loaders/base.py @@ -2,8 +2,9 @@ from abc import ABC, abstractmethod from typing import Iterator, List, Optional +from langchain_core.schema import Document + from langchain.document_loaders.blob_loaders import Blob -from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter diff --git a/libs/langchain/langchain/document_loaders/base_o365.py b/libs/langchain/langchain/document_loaders/base_o365.py index 95df0296a71..b0c85e7379a 100644 --- a/libs/langchain/langchain/document_loaders/base_o365.py +++ b/libs/langchain/langchain/document_loaders/base_o365.py @@ -9,10 +9,17 @@ from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Dict, Iterable, List, Sequence, Union +from langchain_core.pydantic_v1 import ( + BaseModel, + BaseSettings, + Field, + FilePath, + SecretStr, +) + from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.blob_loaders.file_system import FileSystemBlobLoader from langchain.document_loaders.blob_loaders.schema import Blob -from langchain.pydantic_v1 import BaseModel, BaseSettings, Field, FilePath, SecretStr if TYPE_CHECKING: from O365 import Account diff --git a/libs/langchain/langchain/document_loaders/blob_loaders/schema.py b/libs/langchain/langchain/document_loaders/blob_loaders/schema.py index 59e69c31b21..9d1e737e374 100644 --- a/libs/langchain/langchain/document_loaders/blob_loaders/schema.py +++ b/libs/langchain/langchain/document_loaders/blob_loaders/schema.py @@ -13,7 +13,7 @@ from io import BufferedReader, BytesIO from pathlib import PurePath from typing import Any, Generator, Iterable, Mapping, Optional, Union -from langchain.pydantic_v1 import BaseModel, root_validator +from langchain_core.pydantic_v1 import BaseModel, root_validator PathLike = Union[str, PurePath] diff --git a/libs/langchain/langchain/document_loaders/concurrent.py b/libs/langchain/langchain/document_loaders/concurrent.py index bb55d670000..e6f01599c5a 100644 --- a/libs/langchain/langchain/document_loaders/concurrent.py +++ b/libs/langchain/langchain/document_loaders/concurrent.py @@ -4,11 +4,12 @@ import concurrent.futures from pathlib import Path from typing import Iterator, Literal, Optional, Sequence, Union +from langchain_core.schema import Document + from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders import BlobLoader, FileSystemBlobLoader from langchain.document_loaders.generic import GenericLoader from langchain.document_loaders.parsers.registry import get_parser -from langchain.schema import Document _PathLike = Union[str, Path] diff --git a/libs/langchain/langchain/document_loaders/docugami.py b/libs/langchain/langchain/document_loaders/docugami.py index 
0b94472a9cc..850f08fa1aa 100644 --- a/libs/langchain/langchain/document_loaders/docugami.py +++ b/libs/langchain/langchain/document_loaders/docugami.py @@ -6,10 +6,10 @@ from pathlib import Path from typing import Any, Dict, List, Mapping, Optional, Sequence, Union import requests +from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader -from langchain.pydantic_v1 import BaseModel, root_validator TD_NAME = "{http://www.w3.org/1999/xhtml}td" TABLE_NAME = "{http://www.w3.org/1999/xhtml}table" diff --git a/libs/langchain/langchain/document_loaders/dropbox.py b/libs/langchain/langchain/document_loaders/dropbox.py index 513d19a6d43..f7cb81c8968 100644 --- a/libs/langchain/langchain/document_loaders/dropbox.py +++ b/libs/langchain/langchain/document_loaders/dropbox.py @@ -11,9 +11,10 @@ import tempfile from pathlib import Path from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import BaseModel, root_validator + from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader -from langchain.pydantic_v1 import BaseModel, root_validator class DropboxLoader(BaseLoader, BaseModel): diff --git a/libs/langchain/langchain/document_loaders/embaas.py b/libs/langchain/langchain/document_loaders/embaas.py index 2b3e6d01a71..8c1025f55f3 100644 --- a/libs/langchain/langchain/document_loaders/embaas.py +++ b/libs/langchain/langchain/document_loaders/embaas.py @@ -3,12 +3,12 @@ import warnings from typing import Any, Dict, Iterator, List, Optional import requests +from langchain_core.pydantic_v1 import BaseModel, root_validator, validator from typing_extensions import NotRequired, TypedDict from langchain.docstore.document import Document from langchain.document_loaders.base import BaseBlobParser, BaseLoader from langchain.document_loaders.blob_loaders import Blob -from langchain.pydantic_v1 import BaseModel, root_validator, validator from langchain.text_splitter import TextSplitter from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/document_loaders/generic.py b/libs/langchain/langchain/document_loaders/generic.py index 26d7577a332..f5ee0d2ff67 100644 --- a/libs/langchain/langchain/document_loaders/generic.py +++ b/libs/langchain/langchain/document_loaders/generic.py @@ -3,10 +3,11 @@ from __future__ import annotations from pathlib import Path from typing import Iterator, List, Literal, Optional, Sequence, Union +from langchain_core.schema import Document + from langchain.document_loaders.base import BaseBlobParser, BaseLoader from langchain.document_loaders.blob_loaders import BlobLoader, FileSystemBlobLoader from langchain.document_loaders.parsers.registry import get_parser -from langchain.schema import Document from langchain.text_splitter import TextSplitter _PathLike = Union[str, Path] diff --git a/libs/langchain/langchain/document_loaders/github.py b/libs/langchain/langchain/document_loaders/github.py index 63cda0f844a..01eb3a0e3d5 100644 --- a/libs/langchain/langchain/document_loaders/github.py +++ b/libs/langchain/langchain/document_loaders/github.py @@ -3,10 +3,10 @@ from datetime import datetime from typing import Dict, Iterator, List, Literal, Optional, Union import requests +from langchain_core.pydantic_v1 import BaseModel, root_validator, validator from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader -from langchain.pydantic_v1 import 
BaseModel, root_validator, validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/document_loaders/googledrive.py b/libs/langchain/langchain/document_loaders/googledrive.py index 513f9bba773..12b352a6541 100644 --- a/libs/langchain/langchain/document_loaders/googledrive.py +++ b/libs/langchain/langchain/document_loaders/googledrive.py @@ -11,9 +11,10 @@ import os from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Union +from langchain_core.pydantic_v1 import BaseModel, root_validator, validator + from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader -from langchain.pydantic_v1 import BaseModel, root_validator, validator SCOPES = ["https://www.googleapis.com/auth/drive.readonly"] diff --git a/libs/langchain/langchain/document_loaders/joplin.py b/libs/langchain/langchain/document_loaders/joplin.py index 62efe62fd9c..3f16d25c928 100644 --- a/libs/langchain/langchain/document_loaders/joplin.py +++ b/libs/langchain/langchain/document_loaders/joplin.py @@ -3,8 +3,9 @@ import urllib from datetime import datetime from typing import Iterator, List, Optional +from langchain_core.schema import Document + from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document from langchain.utils import get_from_env LINK_NOTE_TEMPLATE = "joplin://x-callback-url/openNote?id={id}" diff --git a/libs/langchain/langchain/document_loaders/lakefs.py b/libs/langchain/langchain/document_loaders/lakefs.py index 46460cfebc6..05b0ca9c081 100644 --- a/libs/langchain/langchain/document_loaders/lakefs.py +++ b/libs/langchain/langchain/document_loaders/lakefs.py @@ -5,11 +5,11 @@ from typing import Any, List, Optional from urllib.parse import urljoin import requests +from langchain_core.schema import Document from requests.auth import HTTPBasicAuth from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.unstructured import UnstructuredBaseLoader -from langchain.schema import Document class LakeFSClient: diff --git a/libs/langchain/langchain/document_loaders/onedrive.py b/libs/langchain/langchain/document_loaders/onedrive.py index b79c5dd464c..ab221c5d485 100644 --- a/libs/langchain/langchain/document_loaders/onedrive.py +++ b/libs/langchain/langchain/document_loaders/onedrive.py @@ -4,13 +4,14 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Union +from langchain_core.pydantic_v1 import Field + from langchain.docstore.document import Document from langchain.document_loaders.base_o365 import ( O365BaseLoader, _FileType, ) from langchain.document_loaders.parsers.registry import get_parser -from langchain.pydantic_v1 import Field if TYPE_CHECKING: from O365.drive import Drive, Folder diff --git a/libs/langchain/langchain/document_loaders/onedrive_file.py b/libs/langchain/langchain/document_loaders/onedrive_file.py index 8cc0e7e3753..9c7f4cab171 100644 --- a/libs/langchain/langchain/document_loaders/onedrive_file.py +++ b/libs/langchain/langchain/document_loaders/onedrive_file.py @@ -3,10 +3,11 @@ from __future__ import annotations import tempfile from typing import TYPE_CHECKING, List +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.unstructured import UnstructuredFileLoader -from langchain.pydantic_v1 import BaseModel, Field if 
TYPE_CHECKING: from O365.drive import File diff --git a/libs/langchain/langchain/document_loaders/parsers/audio.py b/libs/langchain/langchain/document_loaders/parsers/audio.py index 344e48a98ca..720ed899cfb 100644 --- a/libs/langchain/langchain/document_loaders/parsers/audio.py +++ b/libs/langchain/langchain/document_loaders/parsers/audio.py @@ -2,9 +2,10 @@ import logging import time from typing import Dict, Iterator, Optional, Tuple +from langchain_core.schema import Document + from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders import Blob -from langchain.schema import Document logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/document_loaders/parsers/docai.py b/libs/langchain/langchain/document_loaders/parsers/docai.py index 01c388d6a67..b8b6a7a26be 100644 --- a/libs/langchain/langchain/document_loaders/parsers/docai.py +++ b/libs/langchain/langchain/document_loaders/parsers/docai.py @@ -10,11 +10,12 @@ import time from dataclasses import dataclass from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence +from langchain_core.utils.iter import batch_iterate + from langchain.docstore.document import Document from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders import Blob from langchain.utilities.vertexai import get_client_info -from langchain.utils.iter import batch_iterate if TYPE_CHECKING: from google.api_core.operation import Operation diff --git a/libs/langchain/langchain/document_loaders/parsers/generic.py b/libs/langchain/langchain/document_loaders/parsers/generic.py index 3d4c0a5ee0b..ff433a540c5 100644 --- a/libs/langchain/langchain/document_loaders/parsers/generic.py +++ b/libs/langchain/langchain/document_loaders/parsers/generic.py @@ -4,9 +4,10 @@ This module contains some logic to help assemble more sophisticated parsers. 
""" from typing import Iterator, Mapping, Optional +from langchain_core.schema import Document + from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders.schema import Blob -from langchain.schema import Document class MimeTypeBasedParser(BaseBlobParser): diff --git a/libs/langchain/langchain/document_loaders/parsers/msword.py b/libs/langchain/langchain/document_loaders/parsers/msword.py index 3823a191974..8bced491ed1 100644 --- a/libs/langchain/langchain/document_loaders/parsers/msword.py +++ b/libs/langchain/langchain/document_loaders/parsers/msword.py @@ -1,8 +1,9 @@ from typing import Iterator +from langchain_core.schema import Document + from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders import Blob -from langchain.schema import Document class MsWordParser(BaseBlobParser): diff --git a/libs/langchain/langchain/document_loaders/parsers/pdf.py b/libs/langchain/langchain/document_loaders/parsers/pdf.py index 606c767714a..722b8b2c79e 100644 --- a/libs/langchain/langchain/document_loaders/parsers/pdf.py +++ b/libs/langchain/langchain/document_loaders/parsers/pdf.py @@ -15,10 +15,10 @@ from typing import ( from urllib.parse import urlparse import numpy as np +from langchain_core.schema import Document from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders import Blob -from langchain.schema import Document if TYPE_CHECKING: import fitz.fitz diff --git a/libs/langchain/langchain/document_loaders/parsers/txt.py b/libs/langchain/langchain/document_loaders/parsers/txt.py index e506c34b464..81ef2071391 100644 --- a/libs/langchain/langchain/document_loaders/parsers/txt.py +++ b/libs/langchain/langchain/document_loaders/parsers/txt.py @@ -1,9 +1,10 @@ """Module for parsing text files..""" from typing import Iterator +from langchain_core.schema import Document + from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders import Blob -from langchain.schema import Document class TextParser(BaseBlobParser): diff --git a/libs/langchain/langchain/document_loaders/rocksetdb.py b/libs/langchain/langchain/document_loaders/rocksetdb.py index fd3095d23d9..cdfe4313e44 100644 --- a/libs/langchain/langchain/document_loaders/rocksetdb.py +++ b/libs/langchain/langchain/document_loaders/rocksetdb.py @@ -1,7 +1,8 @@ from typing import Any, Callable, Iterator, List, Optional, Tuple +from langchain_core.schema import Document + from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document def default_joiner(docs: List[Tuple[str, Any]]) -> str: diff --git a/libs/langchain/langchain/document_loaders/sharepoint.py b/libs/langchain/langchain/document_loaders/sharepoint.py index cf27bb0788b..bb5924f9d09 100644 --- a/libs/langchain/langchain/document_loaders/sharepoint.py +++ b/libs/langchain/langchain/document_loaders/sharepoint.py @@ -3,13 +3,14 @@ from __future__ import annotations from typing import Iterator, List, Optional, Sequence +from langchain_core.pydantic_v1 import Field + from langchain.docstore.document import Document from langchain.document_loaders.base_o365 import ( O365BaseLoader, _FileType, ) from langchain.document_loaders.parsers.registry import get_parser -from langchain.pydantic_v1 import Field class SharePointLoader(O365BaseLoader): diff --git a/libs/langchain/langchain/document_loaders/sitemap.py b/libs/langchain/langchain/document_loaders/sitemap.py index 
d0f0584d7b8..4510d500904 100644 --- a/libs/langchain/langchain/document_loaders/sitemap.py +++ b/libs/langchain/langchain/document_loaders/sitemap.py @@ -3,8 +3,9 @@ import re from typing import Any, Callable, Generator, Iterable, List, Optional, Tuple from urllib.parse import urlparse +from langchain_core.schema import Document + from langchain.document_loaders.web_base import WebBaseLoader -from langchain.schema import Document def _default_parsing_function(content: Any) -> str: diff --git a/libs/langchain/langchain/document_loaders/tensorflow_datasets.py b/libs/langchain/langchain/document_loaders/tensorflow_datasets.py index 82b59d80040..d70fc0da806 100644 --- a/libs/langchain/langchain/document_loaders/tensorflow_datasets.py +++ b/libs/langchain/langchain/document_loaders/tensorflow_datasets.py @@ -1,7 +1,8 @@ from typing import Callable, Dict, Iterator, List, Optional +from langchain_core.schema import Document + from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document from langchain.utilities.tensorflow_datasets import TensorflowDatasets diff --git a/libs/langchain/langchain/document_loaders/youtube.py b/libs/langchain/langchain/document_loaders/youtube.py index f91adff611b..817c7036bbc 100644 --- a/libs/langchain/langchain/document_loaders/youtube.py +++ b/libs/langchain/langchain/document_loaders/youtube.py @@ -6,10 +6,11 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Union from urllib.parse import parse_qs, urlparse +from langchain_core.pydantic_v1 import root_validator +from langchain_core.pydantic_v1.dataclasses import dataclass + from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader -from langchain.pydantic_v1 import root_validator -from langchain.pydantic_v1.dataclasses import dataclass logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/document_transformers/beautiful_soup_transformer.py b/libs/langchain/langchain/document_transformers/beautiful_soup_transformer.py index 0ef52ddde32..30f574c8f3e 100644 --- a/libs/langchain/langchain/document_transformers/beautiful_soup_transformer.py +++ b/libs/langchain/langchain/document_transformers/beautiful_soup_transformer.py @@ -1,6 +1,6 @@ from typing import Any, Iterator, List, Sequence, cast -from langchain.schema import BaseDocumentTransformer, Document +from langchain_core.schema import BaseDocumentTransformer, Document class BeautifulSoupTransformer(BaseDocumentTransformer): diff --git a/libs/langchain/langchain/document_transformers/doctran_text_extract.py b/libs/langchain/langchain/document_transformers/doctran_text_extract.py index 0de109b31d5..7a951dc5abe 100644 --- a/libs/langchain/langchain/document_transformers/doctran_text_extract.py +++ b/libs/langchain/langchain/document_transformers/doctran_text_extract.py @@ -1,6 +1,7 @@ from typing import Any, List, Optional, Sequence -from langchain.schema import BaseDocumentTransformer, Document +from langchain_core.schema import BaseDocumentTransformer, Document + from langchain.utils import get_from_env diff --git a/libs/langchain/langchain/document_transformers/doctran_text_qa.py b/libs/langchain/langchain/document_transformers/doctran_text_qa.py index 84f286a6ce8..463d7a93c49 100644 --- a/libs/langchain/langchain/document_transformers/doctran_text_qa.py +++ b/libs/langchain/langchain/document_transformers/doctran_text_qa.py @@ -1,6 +1,7 @@ from typing import Any, Optional, Sequence -from langchain.schema import 
BaseDocumentTransformer, Document +from langchain_core.schema import BaseDocumentTransformer, Document + from langchain.utils import get_from_env diff --git a/libs/langchain/langchain/document_transformers/doctran_text_translate.py b/libs/langchain/langchain/document_transformers/doctran_text_translate.py index f3793cee983..0b685b311e9 100644 --- a/libs/langchain/langchain/document_transformers/doctran_text_translate.py +++ b/libs/langchain/langchain/document_transformers/doctran_text_translate.py @@ -1,6 +1,7 @@ from typing import Any, Optional, Sequence -from langchain.schema import BaseDocumentTransformer, Document +from langchain_core.schema import BaseDocumentTransformer, Document + from langchain.utils import get_from_env diff --git a/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py b/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py index 1ef8175f4c6..e8f29091796 100644 --- a/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py +++ b/libs/langchain/langchain/document_transformers/embeddings_redundant_filter.py @@ -2,10 +2,10 @@ from typing import Any, Callable, List, Sequence import numpy as np +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import BaseDocumentTransformer, Document +from langchain_core.schema.embeddings import Embeddings -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema import BaseDocumentTransformer, Document -from langchain.schema.embeddings import Embeddings from langchain.utils.math import cosine_similarity diff --git a/libs/langchain/langchain/document_transformers/google_translate.py b/libs/langchain/langchain/document_transformers/google_translate.py index f52e618230a..de7f991c089 100644 --- a/libs/langchain/langchain/document_transformers/google_translate.py +++ b/libs/langchain/langchain/document_transformers/google_translate.py @@ -1,6 +1,7 @@ from typing import Any, Optional, Sequence -from langchain.schema import BaseDocumentTransformer, Document +from langchain_core.schema import BaseDocumentTransformer, Document + from langchain.utilities.vertexai import get_client_info diff --git a/libs/langchain/langchain/document_transformers/html2text.py b/libs/langchain/langchain/document_transformers/html2text.py index 3c123950d3c..9ee10408b3e 100644 --- a/libs/langchain/langchain/document_transformers/html2text.py +++ b/libs/langchain/langchain/document_transformers/html2text.py @@ -1,6 +1,6 @@ from typing import Any, Sequence -from langchain.schema import BaseDocumentTransformer, Document +from langchain_core.schema import BaseDocumentTransformer, Document class Html2TextTransformer(BaseDocumentTransformer): diff --git a/libs/langchain/langchain/document_transformers/long_context_reorder.py b/libs/langchain/langchain/document_transformers/long_context_reorder.py index 8abcaa9cd72..97c3bae9cde 100644 --- a/libs/langchain/langchain/document_transformers/long_context_reorder.py +++ b/libs/langchain/langchain/document_transformers/long_context_reorder.py @@ -1,8 +1,8 @@ """Reorder documents""" from typing import Any, List, Sequence -from langchain.pydantic_v1 import BaseModel -from langchain.schema import BaseDocumentTransformer, Document +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema import BaseDocumentTransformer, Document def _litm_reordering(documents: List[Document]) -> List[Document]: diff --git a/libs/langchain/langchain/document_transformers/nuclia_text_transform.py 
b/libs/langchain/langchain/document_transformers/nuclia_text_transform.py index 387f33b81d5..2bf4fa85646 100644 --- a/libs/langchain/langchain/document_transformers/nuclia_text_transform.py +++ b/libs/langchain/langchain/document_transformers/nuclia_text_transform.py @@ -3,7 +3,8 @@ import json import uuid from typing import Any, Sequence -from langchain.schema.document import BaseDocumentTransformer, Document +from langchain_core.schema.document import BaseDocumentTransformer, Document + from langchain.tools.nuclia.tool import NucliaUnderstandingAPI diff --git a/libs/langchain/langchain/document_transformers/openai_functions.py b/libs/langchain/langchain/document_transformers/openai_functions.py index 1b85fe75c2c..188583c3c47 100644 --- a/libs/langchain/langchain/document_transformers/openai_functions.py +++ b/libs/langchain/langchain/document_transformers/openai_functions.py @@ -1,12 +1,13 @@ """Document transformers that use OpenAI Functions models""" from typing import Any, Dict, Optional, Sequence, Type, Union +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema import BaseDocumentTransformer, Document +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chains.llm import LLMChain from langchain.chains.openai_functions import create_tagging_chain -from langchain.prompts import ChatPromptTemplate -from langchain.pydantic_v1 import BaseModel -from langchain.schema import BaseDocumentTransformer, Document -from langchain.schema.language_model import BaseLanguageModel class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel): @@ -17,7 +18,7 @@ class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel): from langchain.chat_models import ChatOpenAI from langchain.document_transformers import OpenAIMetadataTagger - from langchain.schema import Document + from langchain_core.schema import Document schema = { "properties": { @@ -100,7 +101,7 @@ def create_metadata_tagger( from langchain.chat_models import ChatOpenAI from langchain.document_transformers import create_metadata_tagger - from langchain.schema import Document + from langchain_core.schema import Document schema = { "properties": { diff --git a/libs/langchain/langchain/embeddings/aleph_alpha.py b/libs/langchain/langchain/embeddings/aleph_alpha.py index 91c71a6063e..a28a49c38e6 100644 --- a/libs/langchain/langchain/embeddings/aleph_alpha.py +++ b/libs/langchain/langchain/embeddings/aleph_alpha.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema.embeddings import Embeddings + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/embeddings/awa.py b/libs/langchain/langchain/embeddings/awa.py index f957cb299d4..11105373363 100644 --- a/libs/langchain/langchain/embeddings/awa.py +++ b/libs/langchain/langchain/embeddings/awa.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema.embeddings import Embeddings class AwaEmbeddings(BaseModel, Embeddings): diff --git a/libs/langchain/langchain/embeddings/azure_openai.py b/libs/langchain/langchain/embeddings/azure_openai.py 
index 9729c2bace4..794d6733aff 100644 --- a/libs/langchain/langchain/embeddings/azure_openai.py +++ b/libs/langchain/langchain/embeddings/azure_openai.py @@ -5,8 +5,9 @@ import os import warnings from typing import Dict, Optional, Union +from langchain_core.pydantic_v1 import Field, root_validator + from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.pydantic_v1 import Field, root_validator from langchain.utils import get_from_dict_or_env from langchain.utils.openai import is_openai_v1 diff --git a/libs/langchain/langchain/embeddings/baidu_qianfan_endpoint.py b/libs/langchain/langchain/embeddings/baidu_qianfan_endpoint.py index cd5c6990c67..d440cd497a2 100644 --- a/libs/langchain/langchain/embeddings/baidu_qianfan_endpoint.py +++ b/libs/langchain/langchain/embeddings/baidu_qianfan_endpoint.py @@ -3,8 +3,9 @@ from __future__ import annotations import logging from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema.embeddings import Embeddings + from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/embeddings/base.py b/libs/langchain/langchain/embeddings/base.py index 7314302d183..60ad5dedbf3 100644 --- a/libs/langchain/langchain/embeddings/base.py +++ b/libs/langchain/langchain/embeddings/base.py @@ -1,4 +1,4 @@ -from langchain.schema.embeddings import Embeddings +from langchain_core.schema.embeddings import Embeddings # This is for backwards compatibility __all__ = ["Embeddings"] diff --git a/libs/langchain/langchain/embeddings/bedrock.py b/libs/langchain/langchain/embeddings/bedrock.py index 55064f15544..825f14545b5 100644 --- a/libs/langchain/langchain/embeddings/bedrock.py +++ b/libs/langchain/langchain/embeddings/bedrock.py @@ -4,8 +4,8 @@ import os from functools import partial from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings class BedrockEmbeddings(BaseModel, Embeddings): diff --git a/libs/langchain/langchain/embeddings/cache.py b/libs/langchain/langchain/embeddings/cache.py index ddaa51277a2..621c1d4bc30 100644 --- a/libs/langchain/langchain/embeddings/cache.py +++ b/libs/langchain/langchain/embeddings/cache.py @@ -14,8 +14,9 @@ import uuid from functools import partial from typing import Callable, List, Sequence, Union, cast -from langchain.schema import BaseStore -from langchain.schema.embeddings import Embeddings +from langchain_core.schema import BaseStore +from langchain_core.schema.embeddings import Embeddings + from langchain.storage.encoder_backed import EncoderBackedStore NAMESPACE_UUID = uuid.UUID(int=1985) diff --git a/libs/langchain/langchain/embeddings/clarifai.py b/libs/langchain/langchain/embeddings/clarifai.py index a33c267a57f..9a3ccac7a26 100644 --- a/libs/langchain/langchain/embeddings/clarifai.py +++ b/libs/langchain/langchain/embeddings/clarifai.py @@ -1,8 +1,9 @@ import logging from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from 
langchain_core.schema.embeddings import Embeddings + from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/embeddings/cohere.py b/libs/langchain/langchain/embeddings/cohere.py index 73435eaa932..dd7b74b5cdb 100644 --- a/libs/langchain/langchain/embeddings/cohere.py +++ b/libs/langchain/langchain/embeddings/cohere.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/embeddings/dashscope.py b/libs/langchain/langchain/embeddings/dashscope.py index 31d38e52b15..60a64fc57cd 100644 --- a/libs/langchain/langchain/embeddings/dashscope.py +++ b/libs/langchain/langchain/embeddings/dashscope.py @@ -9,6 +9,8 @@ from typing import ( Optional, ) +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings from requests.exceptions import HTTPError from tenacity import ( before_sleep_log, @@ -18,8 +20,6 @@ from tenacity import ( wait_exponential, ) -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/embeddings/deepinfra.py b/libs/langchain/langchain/embeddings/deepinfra.py index 369b0445039..20d57e51296 100644 --- a/libs/langchain/langchain/embeddings/deepinfra.py +++ b/libs/langchain/langchain/embeddings/deepinfra.py @@ -1,9 +1,9 @@ from typing import Any, Dict, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env DEFAULT_MODEL_ID = "sentence-transformers/clip-ViT-B-32" diff --git a/libs/langchain/langchain/embeddings/edenai.py b/libs/langchain/langchain/embeddings/edenai.py index 8a0f717dc96..da5869e0480 100644 --- a/libs/langchain/langchain/embeddings/edenai.py +++ b/libs/langchain/langchain/embeddings/edenai.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.schema.embeddings import Embeddings + from langchain.utilities.requests import Requests from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/embeddings/elasticsearch.py b/libs/langchain/langchain/embeddings/elasticsearch.py index 412bac5be96..fb0db428597 100644 --- a/libs/langchain/langchain/embeddings/elasticsearch.py +++ b/libs/langchain/langchain/embeddings/elasticsearch.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from elasticsearch import Elasticsearch from elasticsearch.client import MlClient -from langchain.schema.embeddings import Embeddings +from langchain_core.schema.embeddings import Embeddings class ElasticsearchEmbeddings(Embeddings): diff --git a/libs/langchain/langchain/embeddings/embaas.py 
b/libs/langchain/langchain/embeddings/embaas.py index 530206c73ef..3f054764fa1 100644 --- a/libs/langchain/langchain/embeddings/embaas.py +++ b/libs/langchain/langchain/embeddings/embaas.py @@ -1,10 +1,10 @@ from typing import Any, Dict, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings from typing_extensions import NotRequired, TypedDict -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env # Currently supported maximum batch size for embedding requests diff --git a/libs/langchain/langchain/embeddings/ernie.py b/libs/langchain/langchain/embeddings/ernie.py index c69d7d9fc69..9c147b00b4d 100644 --- a/libs/langchain/langchain/embeddings/ernie.py +++ b/libs/langchain/langchain/embeddings/ernie.py @@ -5,9 +5,9 @@ from functools import partial from typing import Dict, List, Optional import requests +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema.embeddings import Embeddings -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/embeddings/fake.py b/libs/langchain/langchain/embeddings/fake.py index 0e04879b463..649aa93c367 100644 --- a/libs/langchain/langchain/embeddings/fake.py +++ b/libs/langchain/langchain/embeddings/fake.py @@ -2,9 +2,8 @@ import hashlib from typing import List import numpy as np - -from langchain.pydantic_v1 import BaseModel -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema.embeddings import Embeddings class FakeEmbeddings(Embeddings, BaseModel): diff --git a/libs/langchain/langchain/embeddings/fastembed.py b/libs/langchain/langchain/embeddings/fastembed.py index cbc2c9ff16b..ed97526fdef 100644 --- a/libs/langchain/langchain/embeddings/fastembed.py +++ b/libs/langchain/langchain/embeddings/fastembed.py @@ -1,9 +1,8 @@ from typing import Any, Dict, List, Literal, Optional import numpy as np - -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings class FastEmbedEmbeddings(BaseModel, Embeddings): diff --git a/libs/langchain/langchain/embeddings/google_palm.py b/libs/langchain/langchain/embeddings/google_palm.py index fcf83b36692..afb38763a8e 100644 --- a/libs/langchain/langchain/embeddings/google_palm.py +++ b/libs/langchain/langchain/embeddings/google_palm.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging from typing import Any, Callable, Dict, List, Optional +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema.embeddings import Embeddings from tenacity import ( before_sleep_log, retry, @@ -11,8 +13,6 @@ from tenacity import ( wait_exponential, ) -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/embeddings/gpt4all.py b/libs/langchain/langchain/embeddings/gpt4all.py index 318cef54152..e0572f60cad 100644 --- 
a/libs/langchain/langchain/embeddings/gpt4all.py +++ b/libs/langchain/langchain/embeddings/gpt4all.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema.embeddings import Embeddings class GPT4AllEmbeddings(BaseModel, Embeddings): diff --git a/libs/langchain/langchain/embeddings/gradient_ai.py b/libs/langchain/langchain/embeddings/gradient_ai.py index 5ff429af76d..290b0d1219e 100644 --- a/libs/langchain/langchain/embeddings/gradient_ai.py +++ b/libs/langchain/langchain/embeddings/gradient_ai.py @@ -7,9 +7,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import aiohttp import numpy as np import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env __all__ = ["GradientEmbeddings"] diff --git a/libs/langchain/langchain/embeddings/huggingface.py b/libs/langchain/langchain/embeddings/huggingface.py index 5f355d0e2bd..81ce689d8bc 100644 --- a/libs/langchain/langchain/embeddings/huggingface.py +++ b/libs/langchain/langchain/embeddings/huggingface.py @@ -1,9 +1,8 @@ from typing import Any, Dict, List, Optional import requests - -from langchain.pydantic_v1 import BaseModel, Extra, Field -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, Field +from langchain_core.schema.embeddings import Embeddings DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large" diff --git a/libs/langchain/langchain/embeddings/huggingface_hub.py b/libs/langchain/langchain/embeddings/huggingface_hub.py index d887be2abc0..2c14614eea7 100644 --- a/libs/langchain/langchain/embeddings/huggingface_hub.py +++ b/libs/langchain/langchain/embeddings/huggingface_hub.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings + from langchain.utils import get_from_dict_or_env DEFAULT_REPO_ID = "sentence-transformers/all-mpnet-base-v2" diff --git a/libs/langchain/langchain/embeddings/javelin_ai_gateway.py b/libs/langchain/langchain/embeddings/javelin_ai_gateway.py index db97b183c7e..871b1838b10 100644 --- a/libs/langchain/langchain/embeddings/javelin_ai_gateway.py +++ b/libs/langchain/langchain/embeddings/javelin_ai_gateway.py @@ -2,8 +2,8 @@ from __future__ import annotations from typing import Any, Iterator, List, Optional -from langchain.pydantic_v1 import BaseModel -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema.embeddings import Embeddings def _chunk(texts: List[str], size: int) -> Iterator[List[str]]: diff --git a/libs/langchain/langchain/embeddings/jina.py b/libs/langchain/langchain/embeddings/jina.py index 9008e2a5ce6..c94728cb4f5 100644 --- a/libs/langchain/langchain/embeddings/jina.py +++ b/libs/langchain/langchain/embeddings/jina.py @@ -2,9 +2,9 @@ import os from typing import Any, Dict, List, Optional import 
requests +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema.embeddings import Embeddings -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/embeddings/johnsnowlabs.py b/libs/langchain/langchain/embeddings/johnsnowlabs.py index 494b66bfbd2..f14924707b9 100644 --- a/libs/langchain/langchain/embeddings/johnsnowlabs.py +++ b/libs/langchain/langchain/embeddings/johnsnowlabs.py @@ -2,8 +2,9 @@ import os import sys from typing import Any, List +from langchain_core.pydantic_v1 import BaseModel, Extra + from langchain.embeddings.base import Embeddings -from langchain.pydantic_v1 import BaseModel, Extra class JohnSnowLabsEmbeddings(BaseModel, Embeddings): diff --git a/libs/langchain/langchain/embeddings/llamacpp.py b/libs/langchain/langchain/embeddings/llamacpp.py index 5da4999132d..b2dbda8a42c 100644 --- a/libs/langchain/langchain/embeddings/llamacpp.py +++ b/libs/langchain/langchain/embeddings/llamacpp.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.schema.embeddings import Embeddings class LlamaCppEmbeddings(BaseModel, Embeddings): diff --git a/libs/langchain/langchain/embeddings/llm_rails.py b/libs/langchain/langchain/embeddings/llm_rails.py index 60312384afc..804f8224aa7 100644 --- a/libs/langchain/langchain/embeddings/llm_rails.py +++ b/libs/langchain/langchain/embeddings/llm_rails.py @@ -4,9 +4,8 @@ import os from typing import List, Optional import requests - -from langchain.pydantic_v1 import BaseModel, Extra -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.schema.embeddings import Embeddings class LLMRailsEmbeddings(BaseModel, Embeddings): diff --git a/libs/langchain/langchain/embeddings/localai.py b/libs/langchain/langchain/embeddings/localai.py index 53f51d55b80..0f0e9ac1c92 100644 --- a/libs/langchain/langchain/embeddings/localai.py +++ b/libs/langchain/langchain/embeddings/localai.py @@ -15,6 +15,9 @@ from typing import ( Union, ) +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.schema.embeddings import Embeddings +from langchain_core.utils import get_pydantic_field_names from tenacity import ( AsyncRetrying, before_sleep_log, @@ -24,9 +27,7 @@ from tenacity import ( wait_exponential, ) -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.schema.embeddings import Embeddings -from langchain.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/embeddings/minimax.py b/libs/langchain/langchain/embeddings/minimax.py index 9b8035d904f..3bc2840cf13 100644 --- a/libs/langchain/langchain/embeddings/minimax.py +++ b/libs/langchain/langchain/embeddings/minimax.py @@ -4,6 +4,8 @@ import logging from typing import Any, Callable, Dict, List, Optional import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings from tenacity import ( before_sleep_log, retry, @@ -11,8 +13,6 @@ from tenacity 
import ( wait_exponential, ) -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/embeddings/mlflow_gateway.py b/libs/langchain/langchain/embeddings/mlflow_gateway.py index ad03ef30f24..7375bd71d35 100644 --- a/libs/langchain/langchain/embeddings/mlflow_gateway.py +++ b/libs/langchain/langchain/embeddings/mlflow_gateway.py @@ -2,8 +2,8 @@ from __future__ import annotations from typing import Any, Iterator, List, Optional -from langchain.pydantic_v1 import BaseModel -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema.embeddings import Embeddings def _chunk(texts: List[str], size: int) -> Iterator[List[str]]: diff --git a/libs/langchain/langchain/embeddings/modelscope_hub.py b/libs/langchain/langchain/embeddings/modelscope_hub.py index 4dd27b6da61..23e72da5ab8 100644 --- a/libs/langchain/langchain/embeddings/modelscope_hub.py +++ b/libs/langchain/langchain/embeddings/modelscope_hub.py @@ -1,7 +1,7 @@ from typing import Any, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.schema.embeddings import Embeddings class ModelScopeEmbeddings(BaseModel, Embeddings): diff --git a/libs/langchain/langchain/embeddings/mosaicml.py b/libs/langchain/langchain/embeddings/mosaicml.py index 934ffb83901..72f8c341d1c 100644 --- a/libs/langchain/langchain/embeddings/mosaicml.py +++ b/libs/langchain/langchain/embeddings/mosaicml.py @@ -1,9 +1,9 @@ from typing import Any, Dict, List, Mapping, Optional, Tuple import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/embeddings/nlpcloud.py b/libs/langchain/langchain/embeddings/nlpcloud.py index 1b698c95861..38b44f7975e 100644 --- a/libs/langchain/langchain/embeddings/nlpcloud.py +++ b/libs/langchain/langchain/embeddings/nlpcloud.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema.embeddings import Embeddings + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/embeddings/octoai_embeddings.py b/libs/langchain/langchain/embeddings/octoai_embeddings.py index 0a39bb790e5..286173054fb 100644 --- a/libs/langchain/langchain/embeddings/octoai_embeddings.py +++ b/libs/langchain/langchain/embeddings/octoai_embeddings.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Mapping, Optional -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.schema.embeddings import Embeddings + from langchain.utils import get_from_dict_or_env DEFAULT_EMBED_INSTRUCTION = "Represent this input: " diff --git a/libs/langchain/langchain/embeddings/ollama.py 
b/libs/langchain/langchain/embeddings/ollama.py index eb1cabbb833..d254677361e 100644 --- a/libs/langchain/langchain/embeddings/ollama.py +++ b/libs/langchain/langchain/embeddings/ollama.py @@ -1,9 +1,8 @@ from typing import Any, Dict, List, Mapping, Optional import requests - -from langchain.pydantic_v1 import BaseModel, Extra -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.schema.embeddings import Embeddings class OllamaEmbeddings(BaseModel, Embeddings): diff --git a/libs/langchain/langchain/embeddings/openai.py b/libs/langchain/langchain/embeddings/openai.py index 7cb69921868..60d55b1c188 100644 --- a/libs/langchain/langchain/embeddings/openai.py +++ b/libs/langchain/langchain/embeddings/openai.py @@ -20,6 +20,9 @@ from typing import ( ) import numpy as np +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.schema.embeddings import Embeddings +from langchain_core.utils import get_pydantic_field_names from packaging.version import Version, parse from tenacity import ( AsyncRetrying, @@ -30,9 +33,7 @@ from tenacity import ( wait_exponential, ) -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.schema.embeddings import Embeddings -from langchain.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/embeddings/sagemaker_endpoint.py b/libs/langchain/langchain/embeddings/sagemaker_endpoint.py index 0e724624ae8..d1f4fe775f1 100644 --- a/libs/langchain/langchain/embeddings/sagemaker_endpoint.py +++ b/libs/langchain/langchain/embeddings/sagemaker_endpoint.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings + from langchain.llms.sagemaker_endpoint import ContentHandlerBase -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]): diff --git a/libs/langchain/langchain/embeddings/self_hosted.py b/libs/langchain/langchain/embeddings/self_hosted.py index 3b223b4aa26..5889999160f 100644 --- a/libs/langchain/langchain/embeddings/self_hosted.py +++ b/libs/langchain/langchain/embeddings/self_hosted.py @@ -1,8 +1,9 @@ from typing import Any, Callable, List +from langchain_core.pydantic_v1 import Extra +from langchain_core.schema.embeddings import Embeddings + from langchain.llms.self_hosted import SelfHostedPipeline -from langchain.pydantic_v1 import Extra -from langchain.schema.embeddings import Embeddings def _embed_documents(pipeline: Any, *args: Any, **kwargs: Any) -> List[List[float]]: diff --git a/libs/langchain/langchain/embeddings/spacy_embeddings.py b/libs/langchain/langchain/embeddings/spacy_embeddings.py index f5cd45f19d3..460cba90c31 100644 --- a/libs/langchain/langchain/embeddings/spacy_embeddings.py +++ b/libs/langchain/langchain/embeddings/spacy_embeddings.py @@ -1,8 +1,8 @@ import importlib.util from typing import Any, Dict, List -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.schema.embeddings import Embeddings class SpacyEmbeddings(BaseModel, 
Embeddings): diff --git a/libs/langchain/langchain/embeddings/tensorflow_hub.py b/libs/langchain/langchain/embeddings/tensorflow_hub.py index 0bb9bcdc801..918bcd0d412 100644 --- a/libs/langchain/langchain/embeddings/tensorflow_hub.py +++ b/libs/langchain/langchain/embeddings/tensorflow_hub.py @@ -1,7 +1,7 @@ from typing import Any, List -from langchain.pydantic_v1 import BaseModel, Extra -from langchain.schema.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.schema.embeddings import Embeddings DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3" diff --git a/libs/langchain/langchain/embeddings/vertexai.py b/libs/langchain/langchain/embeddings/vertexai.py index 16ac1f36cbb..de1206fb7ae 100644 --- a/libs/langchain/langchain/embeddings/vertexai.py +++ b/libs/langchain/langchain/embeddings/vertexai.py @@ -1,8 +1,9 @@ from typing import Dict, List +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema.embeddings import Embeddings + from langchain.llms.vertexai import _VertexAICommon -from langchain.pydantic_v1 import root_validator -from langchain.schema.embeddings import Embeddings from langchain.utilities.vertexai import raise_vertex_import_error diff --git a/libs/langchain/langchain/embeddings/voyageai.py b/libs/langchain/langchain/embeddings/voyageai.py index e02f3de6c10..3f07a67bd30 100644 --- a/libs/langchain/langchain/embeddings/voyageai.py +++ b/libs/langchain/langchain/embeddings/voyageai.py @@ -14,6 +14,9 @@ from typing import ( ) import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator +from langchain_core.schema.embeddings import Embeddings +from langchain_core.utils import convert_to_secret_str from tenacity import ( before_sleep_log, retry, @@ -21,9 +24,7 @@ from tenacity import ( wait_exponential, ) -from langchain.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain.schema.embeddings import Embeddings -from langchain.utils import convert_to_secret_str, get_from_dict_or_env +from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/embeddings/xinference.py b/libs/langchain/langchain/embeddings/xinference.py index 62ab74d2aad..d56bc622d00 100644 --- a/libs/langchain/langchain/embeddings/xinference.py +++ b/libs/langchain/langchain/embeddings/xinference.py @@ -1,7 +1,7 @@ """Wrapper around Xinference embedding models.""" from typing import Any, List, Optional -from langchain.schema.embeddings import Embeddings +from langchain_core.schema.embeddings import Embeddings class XinferenceEmbeddings(Embeddings): diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py index 805a968d01e..3ae5d21d42d 100644 --- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py +++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py @@ -18,6 +18,10 @@ from typing import ( cast, ) +from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.schema import AgentAction, BaseOutputParser, OutputParserException +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -30,9 +34,6 @@ from langchain.evaluation.agents.trajectory_eval_prompt import ( TOOL_FREE_EVAL_CHAT_PROMPT, ) from 
langchain.evaluation.schema import AgentTrajectoryEvaluator, LLMEvalChain -from langchain.pydantic_v1 import Extra, Field -from langchain.schema import AgentAction, BaseOutputParser, OutputParserException -from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_prompt.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_prompt.py index ceebc72ef17..9037a64aeee 100644 --- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_prompt.py +++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_prompt.py @@ -1,8 +1,8 @@ """Prompt for trajectory evaluation chain.""" # flake8: noqa -from langchain.schema.messages import HumanMessage, AIMessage, SystemMessage +from langchain_core.schema.messages import HumanMessage, AIMessage, SystemMessage -from langchain.prompts.chat import ( +from langchain_core.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, ) diff --git a/libs/langchain/langchain/evaluation/comparison/eval_chain.py b/libs/langchain/langchain/evaluation/comparison/eval_chain.py index 451f4eeb70c..1f06af3851e 100644 --- a/libs/langchain/langchain/evaluation/comparison/eval_chain.py +++ b/libs/langchain/langchain/evaluation/comparison/eval_chain.py @@ -5,6 +5,11 @@ import logging import re from typing import Any, Dict, List, Optional, Union +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.schema import RUN_KEY, BaseOutputParser +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import Callbacks from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain @@ -20,10 +25,6 @@ from langchain.evaluation.criteria.eval_chain import ( Criteria, ) from langchain.evaluation.schema import LLMEvalChain, PairwiseStringEvaluator -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import Extra, Field -from langchain.schema import RUN_KEY, BaseOutputParser -from langchain.schema.language_model import BaseLanguageModel logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/evaluation/comparison/prompt.py b/libs/langchain/langchain/evaluation/comparison/prompt.py index bed64a4dc63..c44f389ab96 100644 --- a/libs/langchain/langchain/evaluation/comparison/prompt.py +++ b/libs/langchain/langchain/evaluation/comparison/prompt.py @@ -5,7 +5,7 @@ and answers the question. The prompt is based on the paper from Zheng, et. al. https://arxiv.org/abs/2306.05685 """ # flake8: noqa -from langchain.prompts.chat import ChatPromptTemplate +from langchain_core.prompts.chat import ChatPromptTemplate SYSTEM_MESSAGE = 'Please act as an impartial judge and evaluate the quality \ of the responses provided by two AI assistants to the user question displayed below. 
\ diff --git a/libs/langchain/langchain/evaluation/criteria/eval_chain.py b/libs/langchain/langchain/evaluation/criteria/eval_chain.py index adfc3f23781..f2e0476f79d 100644 --- a/libs/langchain/langchain/evaluation/criteria/eval_chain.py +++ b/libs/langchain/langchain/evaluation/criteria/eval_chain.py @@ -4,14 +4,15 @@ import re from enum import Enum from typing import Any, Dict, List, Mapping, Optional, Union +from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.schema import RUN_KEY, BaseOutputParser, BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import Callbacks from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES from langchain.evaluation.schema import LLMEvalChain, StringEvaluator -from langchain.pydantic_v1 import Extra, Field -from langchain.schema import RUN_KEY, BaseOutputParser, BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel class Criteria(str, Enum): diff --git a/libs/langchain/langchain/evaluation/criteria/prompt.py b/libs/langchain/langchain/evaluation/criteria/prompt.py index ab2c4a67dc9..e5ac19fe038 100644 --- a/libs/langchain/langchain/evaluation/criteria/prompt.py +++ b/libs/langchain/langchain/evaluation/criteria/prompt.py @@ -1,7 +1,7 @@ # flake8: noqa # Credit to https://github.com/openai/evals/tree/main -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate template = """You are assessing a submitted answer on a given task or input based on a set of criteria. Here is the data: [BEGIN DATA] diff --git a/libs/langchain/langchain/evaluation/embedding_distance/base.py b/libs/langchain/langchain/evaluation/embedding_distance/base.py index f4a88b2cd49..182246085dd 100644 --- a/libs/langchain/langchain/evaluation/embedding_distance/base.py +++ b/libs/langchain/langchain/evaluation/embedding_distance/base.py @@ -3,6 +3,9 @@ from enum import Enum from typing import Any, Dict, List, Optional import numpy as np +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import RUN_KEY +from langchain_core.schema.embeddings import Embeddings from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, @@ -12,9 +15,6 @@ from langchain.callbacks.manager import ( from langchain.chains.base import Chain from langchain.embeddings.openai import OpenAIEmbeddings from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import RUN_KEY -from langchain.schema.embeddings import Embeddings from langchain.utils.math import cosine_similarity diff --git a/libs/langchain/langchain/evaluation/loading.py b/libs/langchain/langchain/evaluation/loading.py index 7579cb02ef1..b2d8b63e444 100644 --- a/libs/langchain/langchain/evaluation/loading.py +++ b/libs/langchain/langchain/evaluation/loading.py @@ -1,6 +1,8 @@ """Loading datasets and evaluators.""" from typing import Any, Dict, List, Optional, Sequence, Type, Union +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chains.base import Chain from langchain.chat_models.openai import ChatOpenAI from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain @@ -32,7 +34,6 @@ from langchain.evaluation.string_distance.base import ( 
PairwiseStringDistanceEvalChain, StringDistanceEvalChain, ) -from langchain.schema.language_model import BaseLanguageModel def load_dataset(uri: str) -> List[Dict]: diff --git a/libs/langchain/langchain/evaluation/qa/eval_chain.py b/libs/langchain/langchain/evaluation/qa/eval_chain.py index 60264aa6b91..369c976863b 100644 --- a/libs/langchain/langchain/evaluation/qa/eval_chain.py +++ b/libs/langchain/langchain/evaluation/qa/eval_chain.py @@ -5,14 +5,15 @@ import re import string from typing import Any, List, Optional, Sequence, Tuple +from langchain_core.prompts import PromptTemplate +from langchain_core.pydantic_v1 import Extra +from langchain_core.schema import RUN_KEY +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import Callbacks from langchain.chains.llm import LLMChain from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT from langchain.evaluation.schema import LLMEvalChain, StringEvaluator -from langchain.prompts import PromptTemplate -from langchain.pydantic_v1 import Extra -from langchain.schema import RUN_KEY -from langchain.schema.language_model import BaseLanguageModel def _get_score(text: str) -> Optional[Tuple[str, int]]: diff --git a/libs/langchain/langchain/evaluation/qa/eval_prompt.py b/libs/langchain/langchain/evaluation/qa/eval_prompt.py index 6675732c641..d29a7858ac2 100644 --- a/libs/langchain/langchain/evaluation/qa/eval_prompt.py +++ b/libs/langchain/langchain/evaluation/qa/eval_prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate template = """You are a teacher grading a quiz. You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT. diff --git a/libs/langchain/langchain/evaluation/qa/generate_chain.py b/libs/langchain/langchain/evaluation/qa/generate_chain.py index 5f925e2e35b..90588fee2e5 100644 --- a/libs/langchain/langchain/evaluation/qa/generate_chain.py +++ b/libs/langchain/langchain/evaluation/qa/generate_chain.py @@ -3,12 +3,13 @@ from __future__ import annotations from typing import Any +from langchain_core.pydantic_v1 import Field +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.output_parser import BaseLLMOutputParser + from langchain.chains.llm import LLMChain from langchain.evaluation.qa.generate_prompt import PROMPT from langchain.output_parsers.regex import RegexParser -from langchain.pydantic_v1 import Field -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.output_parser import BaseLLMOutputParser _QA_OUTPUT_PARSER = RegexParser( regex=r"QUESTION: (.*?)\n+ANSWER: (.*)", output_keys=["query", "answer"] diff --git a/libs/langchain/langchain/evaluation/qa/generate_prompt.py b/libs/langchain/langchain/evaluation/qa/generate_prompt.py index aae2845f6e3..50dc318b72d 100644 --- a/libs/langchain/langchain/evaluation/qa/generate_prompt.py +++ b/libs/langchain/langchain/evaluation/qa/generate_prompt.py @@ -1,6 +1,6 @@ # flake8: noqa from langchain.output_parsers.regex import RegexParser -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate template = """You are a teacher coming up with questions to ask on a quiz. Given the following document, please generate a question and answer based on that document. 
diff --git a/libs/langchain/langchain/evaluation/schema.py b/libs/langchain/langchain/evaluation/schema.py index 86fdbaf0165..95140a2ef1b 100644 --- a/libs/langchain/langchain/evaluation/schema.py +++ b/libs/langchain/langchain/evaluation/schema.py @@ -9,9 +9,10 @@ from functools import partial from typing import Any, Optional, Sequence, Tuple, Union from warnings import warn +from langchain_core.schema.agent import AgentAction +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.chains.base import Chain -from langchain.schema.agent import AgentAction -from langchain.schema.language_model import BaseLanguageModel logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/evaluation/scoring/eval_chain.py b/libs/langchain/langchain/evaluation/scoring/eval_chain.py index 1fa10845f22..ecf00495d82 100644 --- a/libs/langchain/langchain/evaluation/scoring/eval_chain.py +++ b/libs/langchain/langchain/evaluation/scoring/eval_chain.py @@ -5,6 +5,11 @@ import logging import re from typing import Any, Dict, List, Optional, Union +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.schema import RUN_KEY, BaseOutputParser +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import Callbacks from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.llm import LLMChain @@ -21,10 +26,6 @@ from langchain.evaluation.scoring.prompt import ( SCORING_TEMPLATE, SCORING_TEMPLATE_WITH_REFERENCE, ) -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import Extra, Field -from langchain.schema import RUN_KEY, BaseOutputParser -from langchain.schema.language_model import BaseLanguageModel logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/evaluation/scoring/prompt.py b/libs/langchain/langchain/evaluation/scoring/prompt.py index 1d25055834c..99f899824dc 100644 --- a/libs/langchain/langchain/evaluation/scoring/prompt.py +++ b/libs/langchain/langchain/evaluation/scoring/prompt.py @@ -5,7 +5,7 @@ and answers the question. The prompt is based on the paper from Zheng, et. al. https://arxiv.org/abs/2306.05685 """ # flake8: noqa -from langchain.prompts.chat import ChatPromptTemplate +from langchain_core.prompts.chat import ChatPromptTemplate SYSTEM_MESSAGE = "You are a helpful assistant." 
diff --git a/libs/langchain/langchain/evaluation/string_distance/base.py b/libs/langchain/langchain/evaluation/string_distance/base.py index 81aa9a9ebe0..07940c644d0 100644 --- a/libs/langchain/langchain/evaluation/string_distance/base.py +++ b/libs/langchain/langchain/evaluation/string_distance/base.py @@ -3,6 +3,9 @@ from enum import Enum from typing import Any, Callable, Dict, List, Optional +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import RUN_KEY + from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -10,8 +13,6 @@ from langchain.callbacks.manager import ( ) from langchain.chains.base import Chain from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import RUN_KEY def _load_rapidfuzz() -> Any: diff --git a/libs/langchain/langchain/formatting.py b/libs/langchain/langchain/formatting.py index ebb865c957d..26193f46676 100644 --- a/libs/langchain/langchain/formatting.py +++ b/libs/langchain/langchain/formatting.py @@ -1,4 +1,4 @@ """DEPRECATED: Kept for backwards compatibility.""" -from langchain.utils.formatting import StrictFormatter, formatter +from langchain_core.utils.formatting import StrictFormatter, formatter __all__ = ["StrictFormatter", "formatter"] diff --git a/libs/langchain/langchain/globals/__init__.py b/libs/langchain/langchain/globals/__init__.py index 6a26c2f0ca2..883ddb7ca47 100644 --- a/libs/langchain/langchain/globals/__init__.py +++ b/libs/langchain/langchain/globals/__init__.py @@ -3,7 +3,7 @@ import warnings from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - from langchain.schema import BaseCache + from langchain_core.schema import BaseCache # DO NOT USE THESE VALUES DIRECTLY! 
diff --git a/libs/langchain/langchain/graphs/graph_document.py b/libs/langchain/langchain/graphs/graph_document.py index 9f72a3ad8e0..00625f1d75f 100644 --- a/libs/langchain/langchain/graphs/graph_document.py +++ b/libs/langchain/langchain/graphs/graph_document.py @@ -2,9 +2,9 @@ from __future__ import annotations from typing import List, Union -from langchain.load.serializable import Serializable -from langchain.pydantic_v1 import Field -from langchain.schema import Document +from langchain_core.load.serializable import Serializable +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import Document class Node(Serializable): diff --git a/libs/langchain/langchain/hub.py b/libs/langchain/langchain/hub.py index ebacbad6c8e..25c5598a169 100644 --- a/libs/langchain/langchain/hub.py +++ b/libs/langchain/langchain/hub.py @@ -3,8 +3,8 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Optional -from langchain.load.dump import dumps -from langchain.load.load import loads +from langchain_core.load.dump import dumps +from langchain_core.load.load import loads if TYPE_CHECKING: from langchainhub import Client diff --git a/libs/langchain/langchain/indexes/_api.py b/libs/langchain/langchain/indexes/_api.py index a386e656341..24681e44003 100644 --- a/libs/langchain/langchain/indexes/_api.py +++ b/libs/langchain/langchain/indexes/_api.py @@ -24,11 +24,12 @@ from typing import ( cast, ) +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import Document +from langchain_core.schema.vectorstore import VectorStore + from langchain.document_loaders.base import BaseLoader from langchain.indexes.base import NAMESPACE_UUID, RecordManager -from langchain.pydantic_v1 import root_validator -from langchain.schema import Document -from langchain.schema.vectorstore import VectorStore T = TypeVar("T") diff --git a/libs/langchain/langchain/indexes/graph.py b/libs/langchain/langchain/indexes/graph.py index 9772a9ef2b0..7fe48520073 100644 --- a/libs/langchain/langchain/indexes/graph.py +++ b/libs/langchain/langchain/indexes/graph.py @@ -1,14 +1,15 @@ """Graph Index Creator.""" from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.prompt_template import BasePromptTemplate + from langchain.chains.llm import LLMChain from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples from langchain.indexes.prompts.knowledge_triplet_extraction import ( KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT, ) -from langchain.pydantic_v1 import BaseModel -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.prompt_template import BasePromptTemplate class GraphIndexCreator(BaseModel): diff --git a/libs/langchain/langchain/indexes/prompts/entity_extraction.py b/libs/langchain/langchain/indexes/prompts/entity_extraction.py index 47cc349cb2b..416ba13eea9 100644 --- a/libs/langchain/langchain/indexes/prompts/entity_extraction.py +++ b/libs/langchain/langchain/indexes/prompts/entity_extraction.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate _DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """You are an AI assistant reading the transcript of a conversation between an AI and a human. Extract all of the proper nouns from the last line of conversation. As a guideline, a proper noun is generally capitalized. 
You should definitely extract all names and places. diff --git a/libs/langchain/langchain/indexes/prompts/entity_summarization.py b/libs/langchain/langchain/indexes/prompts/entity_summarization.py index 41e97f5f62d..aa8ec6ef99f 100644 --- a/libs/langchain/langchain/indexes/prompts/entity_summarization.py +++ b/libs/langchain/langchain/indexes/prompts/entity_summarization.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate _DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE = """You are an AI assistant helping a human keep track of facts about relevant people, places, and concepts in their life. Update the summary of the provided entity in the "Entity" section based on the last line of your conversation with the human. If you are writing the summary for the first time, return a single sentence. The update should only include facts that are relayed in the last line of conversation about the provided entity, and should only contain facts about the provided entity. diff --git a/libs/langchain/langchain/indexes/prompts/knowledge_triplet_extraction.py b/libs/langchain/langchain/indexes/prompts/knowledge_triplet_extraction.py index 0505965c098..70b6d4e2755 100644 --- a/libs/langchain/langchain/indexes/prompts/knowledge_triplet_extraction.py +++ b/libs/langchain/langchain/indexes/prompts/knowledge_triplet_extraction.py @@ -1,7 +1,7 @@ # flake8: noqa from langchain.graphs.networkx_graph import KG_TRIPLE_DELIMITER -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate _DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = ( "You are a networked intelligence helping a human track knowledge triples" diff --git a/libs/langchain/langchain/indexes/vectorstore.py b/libs/langchain/langchain/indexes/vectorstore.py index 8e32dc18fdc..940dceef373 100644 --- a/libs/langchain/langchain/indexes/vectorstore.py +++ b/libs/langchain/langchain/indexes/vectorstore.py @@ -1,15 +1,16 @@ from typing import Any, Dict, List, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Extra, Field +from langchain_core.schema import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.vectorstore import VectorStore + from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain from langchain.chains.retrieval_qa.base import RetrievalQA from langchain.document_loaders.base import BaseLoader from langchain.embeddings.openai import OpenAIEmbeddings from langchain.llms.openai import OpenAI -from langchain.pydantic_v1 import BaseModel, Extra, Field -from langchain.schema import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.vectorstore import VectorStore from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter from langchain.vectorstores.chroma import Chroma diff --git a/libs/langchain/langchain/input.py b/libs/langchain/langchain/input.py index 7fa443ef452..fdd91412d6e 100644 --- a/libs/langchain/langchain/input.py +++ b/libs/langchain/langchain/input.py @@ -1,5 +1,5 @@ """DEPRECATED: Kept for backwards compatibility.""" -from langchain.utils.input import ( +from langchain_core.utils.input import ( get_bolded_text, get_color_mapping, get_colored_text, diff --git a/libs/langchain/langchain/llms/ai21.py b/libs/langchain/langchain/llms/ai21.py index 
cec49832245..616c8a36495 100644 --- a/libs/langchain/langchain/llms/ai21.py +++ b/libs/langchain/langchain/llms/ai21.py @@ -1,11 +1,12 @@ from typing import Any, Dict, List, Optional, cast import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain.utils import convert_to_secret_str, get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class AI21PenaltyData(BaseModel): diff --git a/libs/langchain/langchain/llms/aleph_alpha.py b/libs/langchain/langchain/llms/aleph_alpha.py index e73545756ed..2ed1de43622 100644 --- a/libs/langchain/langchain/llms/aleph_alpha.py +++ b/libs/langchain/langchain/llms/aleph_alpha.py @@ -1,10 +1,12 @@ from typing import Any, Dict, List, Optional, Sequence +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.utils import convert_to_secret_str + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator -from langchain.utils import convert_to_secret_str, get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class AlephAlpha(LLM): diff --git a/libs/langchain/langchain/llms/amazon_api_gateway.py b/libs/langchain/langchain/llms/amazon_api_gateway.py index 1a019dfdd92..d40658a0588 100644 --- a/libs/langchain/langchain/llms/amazon_api_gateway.py +++ b/libs/langchain/langchain/llms/amazon_api_gateway.py @@ -1,11 +1,11 @@ from typing import Any, Dict, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import Extra from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra class ContentHandlerAmazonAPIGateway: diff --git a/libs/langchain/langchain/llms/anthropic.py b/libs/langchain/langchain/llms/anthropic.py index e423f5fd5b8..b93edd26476 100644 --- a/libs/langchain/langchain/llms/anthropic.py +++ b/libs/langchain/langchain/llms/anthropic.py @@ -11,21 +11,22 @@ from typing import ( Optional, ) +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.output import GenerationChunk +from langchain_core.schema.prompt import PromptValue +from langchain_core.utils import ( + check_package_version, + get_pydantic_field_names, +) +from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from langchain.pydantic_v1 import Field, SecretStr, root_validator -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.output import GenerationChunk -from langchain.schema.prompt import PromptValue -from langchain.utils import ( - check_package_version, - get_from_dict_or_env, - get_pydantic_field_names, -) -from langchain.utils.utils import build_extra_kwargs, convert_to_secret_str +from langchain.utils import get_from_dict_or_env class _AnthropicCommon(BaseLanguageModel): diff --git a/libs/langchain/langchain/llms/anyscale.py 
b/libs/langchain/langchain/llms/anyscale.py index 50349af8356..2ffeab40987 100644 --- a/libs/langchain/langchain/llms/anyscale.py +++ b/libs/langchain/langchain/llms/anyscale.py @@ -12,6 +12,11 @@ from typing import ( cast, ) +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.schema import Generation, LLMResult +from langchain_core.schema.output import GenerationChunk +from langchain_core.utils import convert_to_secret_str + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -21,10 +26,7 @@ from langchain.llms.openai import ( acompletion_with_retry, completion_with_retry, ) -from langchain.pydantic_v1 import Field, SecretStr, root_validator -from langchain.schema import Generation, LLMResult -from langchain.schema.output import GenerationChunk -from langchain.utils import convert_to_secret_str, get_from_dict_or_env +from langchain.utils import get_from_dict_or_env def update_token_usage( diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index 469b5e250fd..72028097c2a 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Extra, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Extra, root_validator from langchain.utilities.arcee import ArceeWrapper, DALMFilter from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/aviary.py b/libs/langchain/langchain/llms/aviary.py index 8444a3a041c..5eb4b4a53b1 100644 --- a/libs/langchain/langchain/llms/aviary.py +++ b/libs/langchain/langchain/llms/aviary.py @@ -3,11 +3,11 @@ import os from typing import Any, Dict, List, Mapping, Optional, Union, cast import requests +from langchain_core.pydantic_v1 import Extra, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator from langchain.utils import get_from_dict_or_env TIMEOUT = 60 diff --git a/libs/langchain/langchain/llms/azureml_endpoint.py b/libs/langchain/langchain/llms/azureml_endpoint.py index e33e1c82b73..a8e82cbdbce 100644 --- a/libs/langchain/langchain/llms/azureml_endpoint.py +++ b/libs/langchain/langchain/llms/azureml_endpoint.py @@ -4,9 +4,10 @@ import warnings from abc import abstractmethod from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import BaseModel, validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import BaseModel, validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/baidu_qianfan_endpoint.py b/libs/langchain/langchain/llms/baidu_qianfan_endpoint.py index 53b79085e56..69f5538739d 100644 --- a/libs/langchain/langchain/llms/baidu_qianfan_endpoint.py +++ b/libs/langchain/langchain/llms/baidu_qianfan_endpoint.py @@ -10,13 +10,14 @@ from typing import ( Optional, ) +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema.output import GenerationChunk + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from 
langchain.pydantic_v1 import Field, root_validator -from langchain.schema.output import GenerationChunk from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/bananadev.py b/libs/langchain/langchain/llms/bananadev.py index 3a984a3cb2f..6c1aec08500 100644 --- a/libs/langchain/langchain/llms/bananadev.py +++ b/libs/langchain/langchain/llms/bananadev.py @@ -1,10 +1,11 @@ import logging from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra, Field, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 04fe74b9db6..a6ffab2a186 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -1,1077 +1,15 @@ -"""Base interface for large language models to expose.""" -from __future__ import annotations - -import asyncio -import functools -import inspect -import json -import logging -import warnings -from abc import ABC, abstractmethod -from functools import partial -from pathlib import Path -from typing import ( - Any, - AsyncIterator, - Callable, - Dict, - Iterator, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - Union, - cast, +from langchain_core.llm import ( + LLM, + BaseLLM, + create_base_retry_decorator, + get_prompts, + update_cache, ) -import yaml -from tenacity import ( - RetryCallState, - before_sleep_log, - retry, - retry_base, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import ( - AsyncCallbackManager, - AsyncCallbackManagerForLLMRun, - CallbackManager, - CallbackManagerForLLMRun, - Callbacks, -) -from langchain.globals import get_llm_cache -from langchain.load.dump import dumpd -from langchain.prompts.base import StringPromptValue -from langchain.prompts.chat import ChatPromptValue -from langchain.pydantic_v1 import Field, root_validator, validator -from langchain.schema import Generation, LLMResult, PromptValue, RunInfo -from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput -from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string -from langchain.schema.output import GenerationChunk -from langchain.schema.runnable import RunnableConfig -from langchain.schema.runnable.config import get_config_list - -logger = logging.getLogger(__name__) - - -def _get_verbosity() -> bool: - from langchain.globals import get_verbose - - return get_verbose() - - -@functools.lru_cache -def _log_error_once(msg: str) -> None: - """Log an error once.""" - logger.error(msg) - - -def create_base_retry_decorator( - error_types: List[Type[BaseException]], - max_retries: int = 1, - run_manager: Optional[ - Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] - ] = None, -) -> Callable[[Any], Any]: - """Create a retry decorator for a given LLM and provided list of error types.""" - - _logging = before_sleep_log(logger, logging.WARNING) - - def _before_sleep(retry_state: RetryCallState) -> None: - _logging(retry_state) - if run_manager: - if isinstance(run_manager, AsyncCallbackManagerForLLMRun): - coro = 
run_manager.on_retry(retry_state) - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - loop.create_task(coro) - else: - asyncio.run(coro) - except Exception as e: - _log_error_once(f"Error in on_retry: {e}") - else: - run_manager.on_retry(retry_state) - return None - - min_seconds = 4 - max_seconds = 10 - # Wait 2^x * 1 second between each retry starting with - # 4 seconds, then up to 10 seconds, then 10 seconds afterwards - retry_instance: "retry_base" = retry_if_exception_type(error_types[0]) - for error in error_types[1:]: - retry_instance = retry_instance | retry_if_exception_type(error) - return retry( - reraise=True, - stop=stop_after_attempt(max_retries), - wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), - retry=retry_instance, - before_sleep=_before_sleep, - ) - - -def get_prompts( - params: Dict[str, Any], prompts: List[str] -) -> Tuple[Dict[int, List], str, List[int], List[str]]: - """Get prompts that are already cached.""" - llm_string = str(sorted([(k, v) for k, v in params.items()])) - missing_prompts = [] - missing_prompt_idxs = [] - existing_prompts = {} - llm_cache = get_llm_cache() - for i, prompt in enumerate(prompts): - if llm_cache is not None: - cache_val = llm_cache.lookup(prompt, llm_string) - if isinstance(cache_val, list): - existing_prompts[i] = cache_val - else: - missing_prompts.append(prompt) - missing_prompt_idxs.append(i) - return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts - - -def update_cache( - existing_prompts: Dict[int, List], - llm_string: str, - missing_prompt_idxs: List[int], - new_results: LLMResult, - prompts: List[str], -) -> Optional[dict]: - """Update the cache and get the LLM output.""" - llm_cache = get_llm_cache() - for i, result in enumerate(new_results.generations): - existing_prompts[missing_prompt_idxs[i]] = result - prompt = prompts[missing_prompt_idxs[i]] - if llm_cache is not None: - llm_cache.update(prompt, llm_string, result) - llm_output = new_results.llm_output - return llm_output - - -class BaseLLM(BaseLanguageModel[str], ABC): - """Base LLM abstract interface. - - It should take in a prompt and return a string.""" - - cache: Optional[bool] = None - verbose: bool = Field(default_factory=_get_verbosity) - """Whether to print out response text.""" - callbacks: Callbacks = Field(default=None, exclude=True) - callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) - tags: Optional[List[str]] = Field(default=None, exclude=True) - """Tags to add to the run trace.""" - metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True) - """Metadata to add to the run trace.""" - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @root_validator() - def raise_deprecation(cls, values: Dict) -> Dict: - """Raise deprecation warning if callback_manager is used.""" - if values.get("callback_manager") is not None: - warnings.warn( - "callback_manager is deprecated. Please use callbacks instead.", - DeprecationWarning, - ) - values["callbacks"] = values.pop("callback_manager", None) - return values - - @validator("verbose", pre=True, always=True) - def set_verbose(cls, verbose: Optional[bool]) -> bool: - """If verbose is None, set it. - - This allows users to pass in None as verbose to access the global setting. 
- """ - if verbose is None: - return _get_verbosity() - else: - return verbose - - # --- Runnable methods --- - - @property - def OutputType(self) -> Type[str]: - """Get the input type for this runnable.""" - return str - - def _convert_input(self, input: LanguageModelInput) -> PromptValue: - if isinstance(input, PromptValue): - return input - elif isinstance(input, str): - return StringPromptValue(text=input) - elif isinstance(input, list): - return ChatPromptValue(messages=input) - else: - raise ValueError( - f"Invalid input type {type(input)}. " - "Must be a PromptValue, str, or list of BaseMessages." - ) - - def invoke( - self, - input: LanguageModelInput, - config: Optional[RunnableConfig] = None, - *, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> str: - config = config or {} - return ( - self.generate_prompt( - [self._convert_input(input)], - stop=stop, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - **kwargs, - ) - .generations[0][0] - .text - ) - - async def ainvoke( - self, - input: LanguageModelInput, - config: Optional[RunnableConfig] = None, - *, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> str: - config = config or {} - llm_result = await self.agenerate_prompt( - [self._convert_input(input)], - stop=stop, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - **kwargs, - ) - return llm_result.generations[0][0].text - - def batch( - self, - inputs: List[LanguageModelInput], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Any, - ) -> List[str]: - if not inputs: - return [] - - config = get_config_list(config, len(inputs)) - max_concurrency = config[0].get("max_concurrency") - - if max_concurrency is None: - try: - llm_result = self.generate_prompt( - [self._convert_input(input) for input in inputs], - callbacks=[c.get("callbacks") for c in config], - tags=[c.get("tags") for c in config], - metadata=[c.get("metadata") for c in config], - run_name=[c.get("run_name") for c in config], - **kwargs, - ) - return [g[0].text for g in llm_result.generations] - except Exception as e: - if return_exceptions: - return cast(List[str], [e for _ in inputs]) - else: - raise e - else: - batches = [ - inputs[i : i + max_concurrency] - for i in range(0, len(inputs), max_concurrency) - ] - config = [{**c, "max_concurrency": None} for c in config] # type: ignore[misc] - return [ - output - for i, batch in enumerate(batches) - for output in self.batch( - batch, - config=config[i * max_concurrency : (i + 1) * max_concurrency], - return_exceptions=return_exceptions, - **kwargs, - ) - ] - - async def abatch( - self, - inputs: List[LanguageModelInput], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Any, - ) -> List[str]: - if not inputs: - return [] - config = get_config_list(config, len(inputs)) - max_concurrency = config[0].get("max_concurrency") - - if max_concurrency is None: - try: - llm_result = await self.agenerate_prompt( - [self._convert_input(input) for input in inputs], - callbacks=[c.get("callbacks") for c in config], - tags=[c.get("tags") for c in config], - metadata=[c.get("metadata") for c in config], - run_name=[c.get("run_name") for c in config], - **kwargs, - ) - return [g[0].text for g in llm_result.generations] - except Exception as e: - if 
return_exceptions: - return cast(List[str], [e for _ in inputs]) - else: - raise e - else: - batches = [ - inputs[i : i + max_concurrency] - for i in range(0, len(inputs), max_concurrency) - ] - config = [{**c, "max_concurrency": None} for c in config] # type: ignore[misc] - return [ - output - for i, batch in enumerate(batches) - for output in await self.abatch( - batch, - config=config[i * max_concurrency : (i + 1) * max_concurrency], - return_exceptions=return_exceptions, - **kwargs, - ) - ] - - def stream( - self, - input: LanguageModelInput, - config: Optional[RunnableConfig] = None, - *, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> Iterator[str]: - if type(self)._stream == BaseLLM._stream: - # model doesn't implement streaming, so use default implementation - yield self.invoke(input, config=config, stop=stop, **kwargs) - else: - prompt = self._convert_input(input).to_string() - config = config or {} - params = self.dict() - params["stop"] = stop - params = {**params, **kwargs} - options = {"stop": stop} - callback_manager = CallbackManager.configure( - config.get("callbacks"), - self.callbacks, - self.verbose, - config.get("tags"), - self.tags, - config.get("metadata"), - self.metadata, - ) - (run_manager,) = callback_manager.on_llm_start( - dumpd(self), - [prompt], - invocation_params=params, - options=options, - name=config.get("run_name"), - ) - try: - generation: Optional[GenerationChunk] = None - for chunk in self._stream( - prompt, stop=stop, run_manager=run_manager, **kwargs - ): - yield chunk.text - if generation is None: - generation = chunk - else: - generation += chunk - assert generation is not None - except BaseException as e: - run_manager.on_llm_error(e) - raise e - else: - run_manager.on_llm_end(LLMResult(generations=[[generation]])) - - async def astream( - self, - input: LanguageModelInput, - config: Optional[RunnableConfig] = None, - *, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> AsyncIterator[str]: - if type(self)._astream == BaseLLM._astream: - # model doesn't implement streaming, so use default implementation - yield await self.ainvoke(input, config=config, stop=stop, **kwargs) - else: - prompt = self._convert_input(input).to_string() - config = config or {} - params = self.dict() - params["stop"] = stop - params = {**params, **kwargs} - options = {"stop": stop} - callback_manager = AsyncCallbackManager.configure( - config.get("callbacks"), - self.callbacks, - self.verbose, - config.get("tags"), - self.tags, - config.get("metadata"), - self.metadata, - ) - (run_manager,) = await callback_manager.on_llm_start( - dumpd(self), - [prompt], - invocation_params=params, - options=options, - name=config.get("run_name"), - ) - try: - generation: Optional[GenerationChunk] = None - async for chunk in self._astream( - prompt, stop=stop, run_manager=run_manager, **kwargs - ): - yield chunk.text - if generation is None: - generation = chunk - else: - generation += chunk - assert generation is not None - except BaseException as e: - await run_manager.on_llm_error(e) - raise e - else: - await run_manager.on_llm_end(LLMResult(generations=[[generation]])) - - # --- Custom methods --- - - @abstractmethod - def _generate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - """Run the LLM on the given prompts.""" - - async def _agenerate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: 
Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - """Run the LLM on the given prompts.""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(self._generate, **kwargs), prompts, stop, run_manager - ) - - def _stream( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[GenerationChunk]: - raise NotImplementedError() - - def _astream( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[GenerationChunk]: - raise NotImplementedError() - - def generate_prompt( - self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, - callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, - **kwargs: Any, - ) -> LLMResult: - prompt_strings = [p.to_string() for p in prompts] - return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs) - - async def agenerate_prompt( - self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, - callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, - **kwargs: Any, - ) -> LLMResult: - prompt_strings = [p.to_string() for p in prompts] - return await self.agenerate( - prompt_strings, stop=stop, callbacks=callbacks, **kwargs - ) - - def _generate_helper( - self, - prompts: List[str], - stop: Optional[List[str]], - run_managers: List[CallbackManagerForLLMRun], - new_arg_supported: bool, - **kwargs: Any, - ) -> LLMResult: - try: - output = ( - self._generate( - prompts, - stop=stop, - # TODO: support multiple run managers - run_manager=run_managers[0] if run_managers else None, - **kwargs, - ) - if new_arg_supported - else self._generate(prompts, stop=stop) - ) - except BaseException as e: - for run_manager in run_managers: - run_manager.on_llm_error(e) - raise e - flattened_outputs = output.flatten() - for manager, flattened_output in zip(run_managers, flattened_outputs): - manager.on_llm_end(flattened_output) - if run_managers: - output.run = [ - RunInfo(run_id=run_manager.run_id) for run_manager in run_managers - ] - return output - - def generate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, - *, - tags: Optional[Union[List[str], List[List[str]]]] = None, - metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, - run_name: Optional[Union[str, List[str]]] = None, - **kwargs: Any, - ) -> LLMResult: - """Run the LLM on the given prompt and input.""" - if not isinstance(prompts, list): - raise ValueError( - "Argument 'prompts' is expected to be of type List[str], received" - f" argument of type {type(prompts)}." 
- ) - # Create callback managers - if ( - isinstance(callbacks, list) - and callbacks - and ( - isinstance(callbacks[0], (list, BaseCallbackManager)) - or callbacks[0] is None - ) - ): - # We've received a list of callbacks args to apply to each input - assert len(callbacks) == len(prompts) - assert tags is None or ( - isinstance(tags, list) and len(tags) == len(prompts) - ) - assert metadata is None or ( - isinstance(metadata, list) and len(metadata) == len(prompts) - ) - assert run_name is None or ( - isinstance(run_name, list) and len(run_name) == len(prompts) - ) - callbacks = cast(List[Callbacks], callbacks) - tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) - metadata_list = cast( - List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) - ) - run_name_list = run_name or cast( - List[Optional[str]], ([None] * len(prompts)) - ) - callback_managers = [ - CallbackManager.configure( - callback, - self.callbacks, - self.verbose, - tag, - self.tags, - meta, - self.metadata, - ) - for callback, tag, meta in zip(callbacks, tags_list, metadata_list) - ] - else: - # We've received a single callbacks arg to apply to all inputs - callback_managers = [ - CallbackManager.configure( - cast(Callbacks, callbacks), - self.callbacks, - self.verbose, - cast(List[str], tags), - self.tags, - cast(Dict[str, Any], metadata), - self.metadata, - ) - ] * len(prompts) - run_name_list = [cast(Optional[str], run_name)] * len(prompts) - - params = self.dict() - params["stop"] = stop - options = {"stop": stop} - ( - existing_prompts, - llm_string, - missing_prompt_idxs, - missing_prompts, - ) = get_prompts(params, prompts) - disregard_cache = self.cache is not None and not self.cache - new_arg_supported = inspect.signature(self._generate).parameters.get( - "run_manager" - ) - if get_llm_cache() is None or disregard_cache: - if self.cache is not None and self.cache: - raise ValueError( - "Asked to cache, but no cache found at `langchain.cache`." 
- ) - run_managers = [ - callback_manager.on_llm_start( - dumpd(self), - [prompt], - invocation_params=params, - options=options, - name=run_name, - )[0] - for callback_manager, prompt, run_name in zip( - callback_managers, prompts, run_name_list - ) - ] - output = self._generate_helper( - prompts, stop, run_managers, bool(new_arg_supported), **kwargs - ) - return output - if len(missing_prompts) > 0: - run_managers = [ - callback_managers[idx].on_llm_start( - dumpd(self), - [prompts[idx]], - invocation_params=params, - options=options, - name=run_name_list[idx], - )[0] - for idx in missing_prompt_idxs - ] - new_results = self._generate_helper( - missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs - ) - llm_output = update_cache( - existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts - ) - run_info = ( - [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers] - if run_managers - else None - ) - else: - llm_output = {} - run_info = None - generations = [existing_prompts[i] for i in range(len(prompts))] - return LLMResult(generations=generations, llm_output=llm_output, run=run_info) - - async def _agenerate_helper( - self, - prompts: List[str], - stop: Optional[List[str]], - run_managers: List[AsyncCallbackManagerForLLMRun], - new_arg_supported: bool, - **kwargs: Any, - ) -> LLMResult: - try: - output = ( - await self._agenerate( - prompts, - stop=stop, - run_manager=run_managers[0] if run_managers else None, - **kwargs, - ) - if new_arg_supported - else await self._agenerate(prompts, stop=stop) - ) - except BaseException as e: - await asyncio.gather( - *[run_manager.on_llm_error(e) for run_manager in run_managers] - ) - raise e - flattened_outputs = output.flatten() - await asyncio.gather( - *[ - run_manager.on_llm_end(flattened_output) - for run_manager, flattened_output in zip( - run_managers, flattened_outputs - ) - ] - ) - if run_managers: - output.run = [ - RunInfo(run_id=run_manager.run_id) for run_manager in run_managers - ] - return output - - async def agenerate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, - *, - tags: Optional[Union[List[str], List[List[str]]]] = None, - metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, - run_name: Optional[Union[str, List[str]]] = None, - **kwargs: Any, - ) -> LLMResult: - """Run the LLM on the given prompt and input.""" - # Create callback managers - if isinstance(callbacks, list) and ( - isinstance(callbacks[0], (list, BaseCallbackManager)) - or callbacks[0] is None - ): - # We've received a list of callbacks args to apply to each input - assert len(callbacks) == len(prompts) - assert tags is None or ( - isinstance(tags, list) and len(tags) == len(prompts) - ) - assert metadata is None or ( - isinstance(metadata, list) and len(metadata) == len(prompts) - ) - assert run_name is None or ( - isinstance(run_name, list) and len(run_name) == len(prompts) - ) - callbacks = cast(List[Callbacks], callbacks) - tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) - metadata_list = cast( - List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) - ) - run_name_list = run_name or cast( - List[Optional[str]], ([None] * len(prompts)) - ) - callback_managers = [ - AsyncCallbackManager.configure( - callback, - self.callbacks, - self.verbose, - tag, - self.tags, - meta, - self.metadata, - ) - for callback, tag, meta in zip(callbacks, tags_list, metadata_list) - ] - else: - 
# We've received a single callbacks arg to apply to all inputs - callback_managers = [ - AsyncCallbackManager.configure( - cast(Callbacks, callbacks), - self.callbacks, - self.verbose, - cast(List[str], tags), - self.tags, - cast(Dict[str, Any], metadata), - self.metadata, - ) - ] * len(prompts) - run_name_list = [cast(Optional[str], run_name)] * len(prompts) - - params = self.dict() - params["stop"] = stop - options = {"stop": stop} - ( - existing_prompts, - llm_string, - missing_prompt_idxs, - missing_prompts, - ) = get_prompts(params, prompts) - disregard_cache = self.cache is not None and not self.cache - new_arg_supported = inspect.signature(self._agenerate).parameters.get( - "run_manager" - ) - if get_llm_cache() is None or disregard_cache: - if self.cache is not None and self.cache: - raise ValueError( - "Asked to cache, but no cache found at `langchain.cache`." - ) - run_managers = await asyncio.gather( - *[ - callback_manager.on_llm_start( - dumpd(self), - [prompt], - invocation_params=params, - options=options, - name=run_name, - ) - for callback_manager, prompt, run_name in zip( - callback_managers, prompts, run_name_list - ) - ] - ) - run_managers = [r[0] for r in run_managers] - output = await self._agenerate_helper( - prompts, stop, run_managers, bool(new_arg_supported), **kwargs - ) - return output - if len(missing_prompts) > 0: - run_managers = await asyncio.gather( - *[ - callback_managers[idx].on_llm_start( - dumpd(self), - [prompts[idx]], - invocation_params=params, - options=options, - name=run_name_list[idx], - ) - for idx in missing_prompt_idxs - ] - ) - run_managers = [r[0] for r in run_managers] - new_results = await self._agenerate_helper( - missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs - ) - llm_output = update_cache( - existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts - ) - run_info = ( - [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers] - if run_managers - else None - ) - else: - llm_output = {} - run_info = None - generations = [existing_prompts[i] for i in range(len(prompts))] - return LLMResult(generations=generations, llm_output=llm_output, run=run_info) - - def __call__( - self, - prompt: str, - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> str: - """Check Cache and run the LLM on the given prompt and input.""" - if not isinstance(prompt, str): - raise ValueError( - "Argument `prompt` is expected to be a string. Instead found " - f"{type(prompt)}. If you want to run the LLM on multiple prompts, use " - "`generate` instead." 
- ) - return ( - self.generate( - [prompt], - stop=stop, - callbacks=callbacks, - tags=tags, - metadata=metadata, - **kwargs, - ) - .generations[0][0] - .text - ) - - async def _call_async( - self, - prompt: str, - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> str: - """Check Cache and run the LLM on the given prompt and input.""" - result = await self.agenerate( - [prompt], - stop=stop, - callbacks=callbacks, - tags=tags, - metadata=metadata, - **kwargs, - ) - return result.generations[0][0].text - - def predict( - self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any - ) -> str: - if stop is None: - _stop = None - else: - _stop = list(stop) - return self(text, stop=_stop, **kwargs) - - def predict_messages( - self, - messages: List[BaseMessage], - *, - stop: Optional[Sequence[str]] = None, - **kwargs: Any, - ) -> BaseMessage: - text = get_buffer_string(messages) - if stop is None: - _stop = None - else: - _stop = list(stop) - content = self(text, stop=_stop, **kwargs) - return AIMessage(content=content) - - async def apredict( - self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any - ) -> str: - if stop is None: - _stop = None - else: - _stop = list(stop) - return await self._call_async(text, stop=_stop, **kwargs) - - async def apredict_messages( - self, - messages: List[BaseMessage], - *, - stop: Optional[Sequence[str]] = None, - **kwargs: Any, - ) -> BaseMessage: - text = get_buffer_string(messages) - if stop is None: - _stop = None - else: - _stop = list(stop) - content = await self._call_async(text, stop=_stop, **kwargs) - return AIMessage(content=content) - - @property - def _identifying_params(self) -> Mapping[str, Any]: - """Get the identifying parameters.""" - return {} - - def __str__(self) -> str: - """Get a string representation of the object for printing.""" - cls_name = f"\033[1m{self.__class__.__name__}\033[0m" - return f"{cls_name}\nParams: {self._identifying_params}" - - @property - @abstractmethod - def _llm_type(self) -> str: - """Return type of llm.""" - - def dict(self, **kwargs: Any) -> Dict: - """Return a dictionary of the LLM.""" - starter_dict = dict(self._identifying_params) - starter_dict["_type"] = self._llm_type - return starter_dict - - def save(self, file_path: Union[Path, str]) -> None: - """Save the LLM. - - Args: - file_path: Path to file to save the LLM to. - - Example: - .. code-block:: python - - llm.save(file_path="path/llm.yaml") - """ - # Convert file to Path object. - if isinstance(file_path, str): - save_path = Path(file_path) - else: - save_path = file_path - - directory_path = save_path.parent - directory_path.mkdir(parents=True, exist_ok=True) - - # Fetch dictionary to save - prompt_dict = self.dict() - - if save_path.suffix == ".json": - with open(file_path, "w") as f: - json.dump(prompt_dict, f, indent=4) - elif save_path.suffix == ".yaml": - with open(file_path, "w") as f: - yaml.dump(prompt_dict, f, default_flow_style=False) - else: - raise ValueError(f"{save_path} must be json or yaml") - - -class LLM(BaseLLM): - """Base LLM abstract class. - - The purpose of this class is to expose a simpler interface for working - with LLMs, rather than expect the user to implement the full _generate method. 
- """ - - @abstractmethod - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - """Run the LLM on the given prompt and input.""" - - async def _acall( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - """Run the LLM on the given prompt and input.""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(self._call, **kwargs), prompt, stop, run_manager - ) - - def _generate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - """Run the LLM on the given prompt and input.""" - # TODO: add caching here. - generations = [] - new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") - for prompt in prompts: - text = ( - self._call(prompt, stop=stop, run_manager=run_manager, **kwargs) - if new_arg_supported - else self._call(prompt, stop=stop, **kwargs) - ) - generations.append([Generation(text=text)]) - return LLMResult(generations=generations) - - async def _agenerate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - """Run the LLM on the given prompt and input.""" - generations = [] - new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") - for prompt in prompts: - text = ( - await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs) - if new_arg_supported - else await self._acall(prompt, stop=stop, **kwargs) - ) - generations.append([Generation(text=text)]) - return LLMResult(generations=generations) +__all__ = [ + "create_base_retry_decorator", + "get_prompts", + "update_cache", + "BaseLLM", + "LLM", +] diff --git a/libs/langchain/langchain/llms/baseten.py b/libs/langchain/langchain/llms/baseten.py index d8f1fb4e94e..2f75161ac97 100644 --- a/libs/langchain/langchain/llms/baseten.py +++ b/libs/langchain/langchain/llms/baseten.py @@ -1,9 +1,10 @@ import logging from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Field logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/beam.py b/libs/langchain/langchain/llms/beam.py index 20aa0d0e2a0..7f661f5c237 100644 --- a/libs/langchain/langchain/llms/beam.py +++ b/libs/langchain/langchain/llms/beam.py @@ -7,10 +7,10 @@ import time from typing import Any, Dict, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import Extra, Field, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py index e21bd4d088b..7a1dd1b0a09 100644 --- a/libs/langchain/langchain/llms/bedrock.py +++ b/libs/langchain/langchain/llms/bedrock.py @@ -3,11 +3,12 @@ import warnings from abc import ABC from typing import Any, Dict, Iterator, List, Mapping, Optional +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from 
langchain_core.schema.output import GenerationChunk + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.schema.output import GenerationChunk from langchain.utilities.anthropic import ( get_num_tokens_anthropic, get_token_ids_anthropic, diff --git a/libs/langchain/langchain/llms/cerebriumai.py b/libs/langchain/langchain/llms/cerebriumai.py index 9b8a55eb268..00fe2c1683f 100644 --- a/libs/langchain/langchain/llms/cerebriumai.py +++ b/libs/langchain/langchain/llms/cerebriumai.py @@ -1,10 +1,11 @@ import logging from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra, Field, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/clarifai.py b/libs/langchain/langchain/llms/clarifai.py index 0b1dea43939..6da1851a703 100644 --- a/libs/langchain/langchain/llms/clarifai.py +++ b/libs/langchain/langchain/llms/clarifai.py @@ -1,11 +1,12 @@ import logging from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema import Generation, LLMResult + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema import Generation, LLMResult from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/cohere.py b/libs/langchain/langchain/llms/cohere.py index b24e10adb6f..754cfac2826 100644 --- a/libs/langchain/langchain/llms/cohere.py +++ b/libs/langchain/langchain/llms/cohere.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging from typing import Any, Callable, Dict, List, Optional +from langchain_core.load.serializable import Serializable +from langchain_core.pydantic_v1 import Extra, Field, root_validator from tenacity import ( before_sleep_log, retry, @@ -17,8 +19,6 @@ from langchain.callbacks.manager import ( ) from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.load.serializable import Serializable -from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/ctransformers.py b/libs/langchain/langchain/llms/ctransformers.py index ed859a1c80f..ff4961ec5ae 100644 --- a/libs/langchain/langchain/llms/ctransformers.py +++ b/libs/langchain/langchain/llms/ctransformers.py @@ -1,12 +1,13 @@ from functools import partial from typing import Any, Dict, List, Optional, Sequence +from langchain_core.pydantic_v1 import root_validator + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from langchain.pydantic_v1 import root_validator class CTransformers(LLM): diff --git a/libs/langchain/langchain/llms/ctranslate2.py b/libs/langchain/langchain/llms/ctranslate2.py index b6180d674de..060c3d36158 
100644 --- a/libs/langchain/langchain/llms/ctranslate2.py +++ b/libs/langchain/langchain/llms/ctranslate2.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List, Optional, Union +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema.output import Generation, LLMResult + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import BaseLLM -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema.output import Generation, LLMResult class CTranslate2(BaseLLM): diff --git a/libs/langchain/langchain/llms/databricks.py b/libs/langchain/langchain/llms/databricks.py index 6488244ff4d..94eaedceec8 100644 --- a/libs/langchain/langchain/llms/databricks.py +++ b/libs/langchain/langchain/llms/databricks.py @@ -3,10 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional import requests - -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.llms.base import LLM -from langchain.pydantic_v1 import ( +from langchain_core.pydantic_v1 import ( BaseModel, Extra, Field, @@ -15,6 +12,9 @@ from langchain.pydantic_v1 import ( validator, ) +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms.base import LLM + __all__ = ["Databricks"] diff --git a/libs/langchain/langchain/llms/deepinfra.py b/libs/langchain/langchain/llms/deepinfra.py index 0f756615e75..44a8aab957c 100644 --- a/libs/langchain/langchain/llms/deepinfra.py +++ b/libs/langchain/langchain/llms/deepinfra.py @@ -2,13 +2,14 @@ import json from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional import aiohttp +from langchain_core.pydantic_v1 import Extra, root_validator from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) -from langchain.llms.base import LLM, GenerationChunk -from langchain.pydantic_v1 import Extra, root_validator +from langchain.llms.base import LLM +from langchain.schema.output import GenerationChunk from langchain.utilities.requests import Requests from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/deepsparse.py b/libs/langchain/langchain/llms/deepsparse.py index aa7d8612afc..24f1d0e6cd2 100644 --- a/libs/langchain/langchain/llms/deepsparse.py +++ b/libs/langchain/langchain/llms/deepsparse.py @@ -1,13 +1,13 @@ # flake8: noqa from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union -from langchain.pydantic_v1 import root_validator +from langchain_core.pydantic_v1 import root_validator from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.schema.output import GenerationChunk +from langchain_core.schema.output import GenerationChunk class DeepSparse(LLM): diff --git a/libs/langchain/langchain/llms/edenai.py b/libs/langchain/langchain/llms/edenai.py index 8d333c46c1b..832151efa1e 100644 --- a/libs/langchain/langchain/llms/edenai.py +++ b/libs/langchain/langchain/llms/edenai.py @@ -3,6 +3,7 @@ import logging from typing import Any, Dict, List, Literal, Optional from aiohttp import ClientSession +from langchain_core.pydantic_v1 import Extra, Field, root_validator from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -10,7 +11,6 @@ from langchain.callbacks.manager import ( ) from langchain.llms.base import LLM from langchain.llms.utils import 
enforce_stop_tokens -from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.utilities.requests import Requests from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/fake.py b/libs/langchain/langchain/llms/fake.py index cb3ea2792ff..f77f919e6c9 100644 --- a/libs/langchain/langchain/llms/fake.py +++ b/libs/langchain/langchain/llms/fake.py @@ -2,13 +2,14 @@ import asyncio import time from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional +from langchain_core.runnables import RunnableConfig +from langchain_core.schema.language_model import LanguageModelInput + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from langchain.schema.language_model import LanguageModelInput -from langchain.schema.runnable import RunnableConfig class FakeListLLM(LLM): diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py index 72167d83d4e..737d691da8a 100644 --- a/libs/langchain/langchain/llms/fireworks.py +++ b/libs/langchain/langchain/llms/fireworks.py @@ -2,14 +2,15 @@ import asyncio from concurrent.futures import ThreadPoolExecutor from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.schema.output import Generation, GenerationChunk, LLMResult +from langchain_core.utils import convert_to_secret_str + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import BaseLLM, create_base_retry_decorator -from langchain.pydantic_v1 import Field, SecretStr, root_validator -from langchain.schema.output import Generation, GenerationChunk, LLMResult -from langchain.utils import convert_to_secret_str from langchain.utils.env import get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/forefrontai.py b/libs/langchain/langchain/llms/forefrontai.py index 7d82b6896e1..3664144be26 100644 --- a/libs/langchain/langchain/llms/forefrontai.py +++ b/libs/langchain/langchain/llms/forefrontai.py @@ -1,11 +1,11 @@ from typing import Any, Dict, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import Extra, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/gigachat.py b/libs/langchain/langchain/llms/gigachat.py index b484d8bb415..5ea544885c2 100644 --- a/libs/langchain/langchain/llms/gigachat.py +++ b/libs/langchain/langchain/llms/gigachat.py @@ -4,14 +4,15 @@ import logging from functools import cached_property from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +from langchain_core.load.serializable import Serializable +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema.output import Generation, GenerationChunk, LLMResult + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import BaseLLM -from langchain.load.serializable import Serializable -from langchain.pydantic_v1 import root_validator -from langchain.schema.output import Generation, GenerationChunk, LLMResult logger = logging.getLogger(__name__) diff --git 
a/libs/langchain/langchain/llms/google_palm.py b/libs/langchain/langchain/llms/google_palm.py index b15f383d1a4..f3aadc3a0d3 100644 --- a/libs/langchain/langchain/llms/google_palm.py +++ b/libs/langchain/langchain/llms/google_palm.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging from typing import Any, Callable, Dict, List, Optional +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema import Generation, LLMResult from tenacity import ( before_sleep_log, retry, @@ -13,8 +15,6 @@ from tenacity import ( from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms import BaseLLM -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema import Generation, LLMResult from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/gooseai.py b/libs/langchain/langchain/llms/gooseai.py index 67aeb1de186..db08a360a91 100644 --- a/libs/langchain/langchain/llms/gooseai.py +++ b/libs/langchain/langchain/llms/gooseai.py @@ -1,10 +1,12 @@ import logging from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Extra, Field, SecretStr, root_validator -from langchain.utils import convert_to_secret_str, get_from_dict_or_env +from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/gpt4all.py b/libs/langchain/langchain/llms/gpt4all.py index 3f9a397ffb3..2634fd9b96b 100644 --- a/libs/langchain/langchain/llms/gpt4all.py +++ b/libs/langchain/langchain/llms/gpt4all.py @@ -1,10 +1,11 @@ from functools import partial from typing import Any, Dict, List, Mapping, Optional, Set +from langchain_core.pydantic_v1 import Extra, Field, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, Field, root_validator class GPT4All(LLM): diff --git a/libs/langchain/langchain/llms/gradient_ai.py b/libs/langchain/langchain/llms/gradient_ai.py index 1842c4b35bd..3ee979aee35 100644 --- a/libs/langchain/langchain/llms/gradient_ai.py +++ b/libs/langchain/langchain/llms/gradient_ai.py @@ -5,6 +5,8 @@ from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict import aiohttp import requests +from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.schema import Generation, LLMResult from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -12,8 +14,6 @@ from langchain.callbacks.manager import ( ) from langchain.llms.base import BaseLLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, Field, root_validator -from langchain.schema import Generation, LLMResult from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/huggingface_endpoint.py b/libs/langchain/langchain/llms/huggingface_endpoint.py index 850e3d465cd..50c3f008076 100644 --- a/libs/langchain/langchain/llms/huggingface_endpoint.py +++ b/libs/langchain/langchain/llms/huggingface_endpoint.py @@ -1,11 +1,11 @@ from typing import Any, Dict, List, Mapping, Optional import 
requests +from langchain_core.pydantic_v1 import Extra, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator from langchain.utils import get_from_dict_or_env VALID_TASKS = ("text2text-generation", "text-generation", "summarization") diff --git a/libs/langchain/langchain/llms/huggingface_hub.py b/libs/langchain/langchain/llms/huggingface_hub.py index 45b326c4e9b..4bda8a6cb9a 100644 --- a/libs/langchain/langchain/llms/huggingface_hub.py +++ b/libs/langchain/langchain/llms/huggingface_hub.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator from langchain.utils import get_from_dict_or_env DEFAULT_REPO_ID = "gpt2" diff --git a/libs/langchain/langchain/llms/huggingface_pipeline.py b/libs/langchain/langchain/llms/huggingface_pipeline.py index b72ef2aa40c..095f9b186f6 100644 --- a/libs/langchain/langchain/llms/huggingface_pipeline.py +++ b/libs/langchain/langchain/llms/huggingface_pipeline.py @@ -4,11 +4,12 @@ import importlib.util import logging from typing import Any, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra +from langchain_core.schema import Generation, LLMResult + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import BaseLLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra -from langchain.schema import Generation, LLMResult DEFAULT_MODEL_ID = "gpt2" DEFAULT_TASK = "text-generation" diff --git a/libs/langchain/langchain/llms/huggingface_text_gen_inference.py b/libs/langchain/langchain/llms/huggingface_text_gen_inference.py index 683b2f4dde3..15267810252 100644 --- a/libs/langchain/langchain/llms/huggingface_text_gen_inference.py +++ b/libs/langchain/langchain/llms/huggingface_text_gen_inference.py @@ -1,14 +1,15 @@ import logging from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.schema.output import GenerationChunk +from langchain_core.utils import get_pydantic_field_names + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from langchain.pydantic_v1 import Extra, Field, root_validator -from langchain.schema.output import GenerationChunk -from langchain.utils import get_pydantic_field_names logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/human.py b/libs/langchain/langchain/llms/human.py index d6ea10b8ec2..018781ffe8e 100644 --- a/libs/langchain/langchain/llms/human.py +++ b/libs/langchain/langchain/llms/human.py @@ -1,9 +1,10 @@ from typing import Any, Callable, List, Mapping, Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Field def _display_prompt(prompt: str) -> None: diff --git a/libs/langchain/langchain/llms/javelin_ai_gateway.py b/libs/langchain/langchain/llms/javelin_ai_gateway.py index 
80b04424b61..56ff8d278f0 100644 --- a/libs/langchain/langchain/llms/javelin_ai_gateway.py +++ b/libs/langchain/langchain/llms/javelin_ai_gateway.py @@ -2,12 +2,13 @@ from __future__ import annotations from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import BaseModel, Extra + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from langchain.pydantic_v1 import BaseModel, Extra # Ignoring type because below is valid pydantic code diff --git a/libs/langchain/langchain/llms/llamacpp.py b/libs/langchain/langchain/llms/llamacpp.py index 0e4b7e8f635..8da36d325b1 100644 --- a/libs/langchain/langchain/llms/llamacpp.py +++ b/libs/langchain/langchain/llms/llamacpp.py @@ -4,12 +4,13 @@ import logging from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema.output import GenerationChunk +from langchain_core.utils import get_pydantic_field_names +from langchain_core.utils.utils import build_extra_kwargs + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema.output import GenerationChunk -from langchain.utils import get_pydantic_field_names -from langchain.utils.utils import build_extra_kwargs if TYPE_CHECKING: from llama_cpp import LlamaGrammar diff --git a/libs/langchain/langchain/llms/manifest.py b/libs/langchain/langchain/llms/manifest.py index 5e2416ab412..1fa5ffd0386 100644 --- a/libs/langchain/langchain/llms/manifest.py +++ b/libs/langchain/langchain/llms/manifest.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Extra, root_validator class ManifestWrapper(LLM): diff --git a/libs/langchain/langchain/llms/minimax.py b/libs/langchain/langchain/llms/minimax.py index 4a6abc753ba..488f296d5a8 100644 --- a/libs/langchain/langchain/llms/minimax.py +++ b/libs/langchain/langchain/llms/minimax.py @@ -10,13 +10,13 @@ from typing import ( ) import requests +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/mlflow_ai_gateway.py b/libs/langchain/langchain/llms/mlflow_ai_gateway.py index dcb75c5190b..62b025d2155 100644 --- a/libs/langchain/langchain/llms/mlflow_ai_gateway.py +++ b/libs/langchain/langchain/llms/mlflow_ai_gateway.py @@ -2,9 +2,10 @@ from __future__ import annotations from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import BaseModel, Extra + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import BaseModel, Extra # Ignoring type because below is valid pydantic code diff --git a/libs/langchain/langchain/llms/modal.py b/libs/langchain/langchain/llms/modal.py index c18b1599375..dc1535714ec 100644 --- 
a/libs/langchain/langchain/llms/modal.py +++ b/libs/langchain/langchain/llms/modal.py @@ -2,11 +2,11 @@ import logging from typing import Any, Dict, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import Extra, Field, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, Field, root_validator logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/mosaicml.py b/libs/langchain/langchain/llms/mosaicml.py index 9103988f5cb..92a5d6d79ed 100644 --- a/libs/langchain/langchain/llms/mosaicml.py +++ b/libs/langchain/langchain/llms/mosaicml.py @@ -1,11 +1,11 @@ from typing import Any, Dict, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import Extra, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator from langchain.utils import get_from_dict_or_env INSTRUCTION_KEY = "### Instruction:" diff --git a/libs/langchain/langchain/llms/nlpcloud.py b/libs/langchain/langchain/llms/nlpcloud.py index d908e374e0f..c359284bd19 100644 --- a/libs/langchain/langchain/llms/nlpcloud.py +++ b/libs/langchain/langchain/llms/nlpcloud.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Extra, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/octoai_endpoint.py b/libs/langchain/langchain/llms/octoai_endpoint.py index 450b0343edb..d563b1faa72 100644 --- a/libs/langchain/langchain/llms/octoai_endpoint.py +++ b/libs/langchain/langchain/llms/octoai_endpoint.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/ollama.py b/libs/langchain/langchain/llms/ollama.py index 43451e12fe7..3e6232f4397 100644 --- a/libs/langchain/langchain/llms/ollama.py +++ b/libs/langchain/langchain/llms/ollama.py @@ -2,13 +2,13 @@ import json from typing import Any, Dict, Iterator, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import Extra +from langchain_core.schema import LLMResult +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.output import GenerationChunk from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import BaseLLM -from langchain.pydantic_v1 import Extra -from langchain.schema import LLMResult -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.output import GenerationChunk def _stream_response_to_generation_chunk( diff --git a/libs/langchain/langchain/llms/opaqueprompts.py b/libs/langchain/langchain/llms/opaqueprompts.py index af3ccc96721..9a588190454 100644 --- a/libs/langchain/langchain/llms/opaqueprompts.py +++ 
b/libs/langchain/langchain/llms/opaqueprompts.py @@ -1,10 +1,11 @@ import logging from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema.language_model import BaseLanguageModel from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/openai.py b/libs/langchain/langchain/llms/openai.py index f420173a6fe..1a0439709c7 100644 --- a/libs/langchain/langchain/llms/openai.py +++ b/libs/langchain/langchain/llms/openai.py @@ -21,17 +21,19 @@ from typing import ( Union, ) +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import Generation, LLMResult +from langchain_core.schema.output import GenerationChunk +from langchain_core.utils import get_pydantic_field_names +from langchain_core.utils.utils import build_extra_kwargs + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import BaseLLM, create_base_retry_decorator -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import Generation, LLMResult -from langchain.schema.output import GenerationChunk -from langchain.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain.utils import get_from_dict_or_env from langchain.utils.openai import is_openai_v1 -from langchain.utils.utils import build_extra_kwargs logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/openllm.py b/libs/langchain/langchain/llms/openllm.py index f35088aff39..4472e971436 100644 --- a/libs/langchain/langchain/llms/openllm.py +++ b/libs/langchain/langchain/llms/openllm.py @@ -15,12 +15,13 @@ from typing import ( overload, ) +from langchain_core.pydantic_v1 import PrivateAttr + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from langchain.pydantic_v1 import PrivateAttr if TYPE_CHECKING: import openllm diff --git a/libs/langchain/langchain/llms/openlm.py b/libs/langchain/langchain/llms/openlm.py index 156add03ecf..fcbf3af5842 100644 --- a/libs/langchain/langchain/llms/openlm.py +++ b/libs/langchain/langchain/llms/openlm.py @@ -1,7 +1,8 @@ from typing import Any, Dict +from langchain_core.pydantic_v1 import root_validator + from langchain.llms.openai import BaseOpenAI -from langchain.pydantic_v1 import root_validator class OpenLM(BaseOpenAI): diff --git a/libs/langchain/langchain/llms/pai_eas_endpoint.py b/libs/langchain/langchain/llms/pai_eas_endpoint.py index cd8634b5e2d..4e4bd7cde03 100644 --- a/libs/langchain/langchain/llms/pai_eas_endpoint.py +++ b/libs/langchain/langchain/llms/pai_eas_endpoint.py @@ -3,12 +3,12 @@ import logging from typing import Any, Dict, Iterator, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema.output import GenerationChunk from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import root_validator -from langchain.schema.output import GenerationChunk from langchain.utils import get_from_dict_or_env logger = 
logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/petals.py b/libs/langchain/langchain/llms/petals.py index 1069d6a79a9..ecd8df336e7 100644 --- a/libs/langchain/langchain/llms/petals.py +++ b/libs/langchain/langchain/llms/petals.py @@ -1,10 +1,11 @@ import logging from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra, Field, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/pipelineai.py b/libs/langchain/langchain/llms/pipelineai.py index 248666e6f44..8572c86f4e6 100644 --- a/libs/langchain/langchain/llms/pipelineai.py +++ b/libs/langchain/langchain/llms/pipelineai.py @@ -1,10 +1,11 @@ import logging from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/predibase.py b/libs/langchain/langchain/llms/predibase.py index 79c758dab3d..2448abdd05d 100644 --- a/libs/langchain/langchain/llms/predibase.py +++ b/libs/langchain/langchain/llms/predibase.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List, Mapping, Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Field class Predibase(LLM): diff --git a/libs/langchain/langchain/llms/predictionguard.py b/libs/langchain/langchain/llms/predictionguard.py index 6977381129b..850bcb763d5 100644 --- a/libs/langchain/langchain/llms/predictionguard.py +++ b/libs/langchain/langchain/llms/predictionguard.py @@ -1,10 +1,11 @@ import logging from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Extra, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/promptlayer_openai.py b/libs/langchain/langchain/llms/promptlayer_openai.py index bbeffbf66be..35434481268 100644 --- a/libs/langchain/langchain/llms/promptlayer_openai.py +++ b/libs/langchain/langchain/llms/promptlayer_openai.py @@ -1,12 +1,13 @@ import datetime from typing import Any, List, Optional +from langchain_core.schema import LLMResult + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.openai import OpenAI, OpenAIChat -from langchain.schema import LLMResult class PromptLayerOpenAI(OpenAI): diff --git a/libs/langchain/langchain/llms/replicate.py b/libs/langchain/langchain/llms/replicate.py index e94cc1fad71..34e60851ed7 100644 --- a/libs/langchain/langchain/llms/replicate.py +++ b/libs/langchain/langchain/llms/replicate.py @@ -3,10 +3,11 @@ from __future__ 
import annotations import logging from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional +from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.schema.output import GenerationChunk + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Extra, Field, root_validator -from langchain.schema.output import GenerationChunk from langchain.utils import get_from_dict_or_env if TYPE_CHECKING: diff --git a/libs/langchain/langchain/llms/rwkv.py b/libs/langchain/langchain/llms/rwkv.py index 8072b2b91b6..c0f709ab67f 100644 --- a/libs/langchain/langchain/llms/rwkv.py +++ b/libs/langchain/langchain/llms/rwkv.py @@ -5,10 +5,11 @@ Based on https://github.com/saharNooby/rwkv.cpp/blob/master/rwkv/chat_with_bot.p """ from typing import Any, Dict, List, Mapping, Optional, Set +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import BaseModel, Extra, root_validator class RWKV(LLM, BaseModel): diff --git a/libs/langchain/langchain/llms/sagemaker_endpoint.py b/libs/langchain/langchain/llms/sagemaker_endpoint.py index a64e58f5db2..c95bb7e6fa7 100644 --- a/libs/langchain/langchain/llms/sagemaker_endpoint.py +++ b/libs/langchain/langchain/llms/sagemaker_endpoint.py @@ -4,10 +4,11 @@ import json from abc import abstractmethod from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union +from langchain_core.pydantic_v1 import Extra, root_validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]]) OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator]) diff --git a/libs/langchain/langchain/llms/self_hosted.py b/libs/langchain/langchain/llms/self_hosted.py index 8e188814933..a8b94ea606f 100644 --- a/libs/langchain/langchain/llms/self_hosted.py +++ b/libs/langchain/langchain/llms/self_hosted.py @@ -3,10 +3,11 @@ import logging import pickle from typing import Any, Callable, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/self_hosted_hugging_face.py b/libs/langchain/langchain/llms/self_hosted_hugging_face.py index 77c00d81734..86d2879891f 100644 --- a/libs/langchain/langchain/llms/self_hosted_hugging_face.py +++ b/libs/langchain/langchain/llms/self_hosted_hugging_face.py @@ -2,10 +2,11 @@ import importlib.util import logging from typing import Any, Callable, List, Mapping, Optional +from langchain_core.pydantic_v1 import Extra + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.self_hosted import SelfHostedPipeline from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra DEFAULT_MODEL_ID = "gpt2" DEFAULT_TASK = "text-generation" diff --git a/libs/langchain/langchain/llms/stochasticai.py b/libs/langchain/langchain/llms/stochasticai.py index 
2afe7d695e4..0737a1e71d8 100644 --- a/libs/langchain/langchain/llms/stochasticai.py +++ b/libs/langchain/langchain/llms/stochasticai.py @@ -3,11 +3,11 @@ import time from typing import Any, Dict, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import Extra, Field, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/symblai_nebula.py b/libs/langchain/langchain/llms/symblai_nebula.py index cefedd4cc6a..ddf238c7af9 100644 --- a/libs/langchain/langchain/llms/symblai_nebula.py +++ b/libs/langchain/langchain/llms/symblai_nebula.py @@ -3,6 +3,8 @@ import logging from typing import Any, Callable, Dict, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str from requests import ConnectTimeout, ReadTimeout, RequestException from tenacity import ( before_sleep_log, @@ -15,9 +17,7 @@ from tenacity import ( from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, SecretStr, root_validator -from langchain.utils import convert_to_secret_str -from langchain.utils.env import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env DEFAULT_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai" DEFAULT_NEBULA_SERVICE_PATH = "/v1/model/generate" diff --git a/libs/langchain/langchain/llms/textgen.py b/libs/langchain/langchain/llms/textgen.py index e962c7be278..93b387c85b8 100644 --- a/libs/langchain/langchain/llms/textgen.py +++ b/libs/langchain/langchain/llms/textgen.py @@ -3,14 +3,14 @@ import logging from typing import Any, AsyncIterator, Dict, Iterator, List, Optional import requests +from langchain_core.pydantic_v1 import Field +from langchain_core.schema.output import GenerationChunk from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from langchain.pydantic_v1 import Field -from langchain.schema.output import GenerationChunk logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/titan_takeoff.py b/libs/langchain/langchain/llms/titan_takeoff.py index b28d4884df9..af9a5097985 100644 --- a/libs/langchain/langchain/llms/titan_takeoff.py +++ b/libs/langchain/langchain/llms/titan_takeoff.py @@ -1,12 +1,12 @@ from typing import Any, Iterator, List, Mapping, Optional import requests +from langchain_core.schema.output import GenerationChunk from requests.exceptions import ConnectionError from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.schema.output import GenerationChunk class TitanTakeoff(LLM): diff --git a/libs/langchain/langchain/llms/titan_takeoff_pro.py b/libs/langchain/langchain/llms/titan_takeoff_pro.py index 0a8dc78f3e6..cd3cc7bd29e 100644 --- a/libs/langchain/langchain/llms/titan_takeoff_pro.py +++ b/libs/langchain/langchain/llms/titan_takeoff_pro.py @@ -1,12 +1,12 @@ from typing import Any, Iterator, List, Mapping, Optional import requests +from langchain_core.schema.output import 
GenerationChunk from requests.exceptions import ConnectionError from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.schema.output import GenerationChunk class TitanTakeoffPro(LLM): diff --git a/libs/langchain/langchain/llms/together.py b/libs/langchain/langchain/llms/together.py index c5445c1ad0c..46ada2e9d30 100644 --- a/libs/langchain/langchain/llms/together.py +++ b/libs/langchain/langchain/llms/together.py @@ -3,13 +3,13 @@ import logging from typing import Any, Dict, List, Optional from aiohttp import ClientSession +from langchain_core.pydantic_v1 import Extra, root_validator from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from langchain.pydantic_v1 import Extra, root_validator from langchain.utilities.requests import Requests from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/tongyi.py b/libs/langchain/langchain/llms/tongyi.py index 51dba49c0c4..bd5e3df5e53 100644 --- a/libs/langchain/langchain/llms/tongyi.py +++ b/libs/langchain/langchain/llms/tongyi.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging from typing import Any, Callable, Dict, List, Optional +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema import Generation, LLMResult from requests.exceptions import HTTPError from tenacity import ( before_sleep_log, @@ -14,8 +16,6 @@ from tenacity import ( from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import Generation, LLMResult from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py index fd6b31f40e3..6539e7647e5 100644 --- a/libs/langchain/langchain/llms/vertexai.py +++ b/libs/langchain/langchain/llms/vertexai.py @@ -13,17 +13,18 @@ from typing import ( Union, ) +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.schema import ( + Generation, + LLMResult, +) +from langchain_core.schema.output import GenerationChunk + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import BaseLLM, create_base_retry_decorator -from langchain.pydantic_v1 import BaseModel, Field, root_validator -from langchain.schema import ( - Generation, - LLMResult, -) -from langchain.schema.output import GenerationChunk from langchain.utilities.vertexai import ( get_client_info, init_vertexai, diff --git a/libs/langchain/langchain/llms/vllm.py b/libs/langchain/langchain/llms/vllm.py index f33e3cef96d..e7fced22c13 100644 --- a/libs/langchain/langchain/llms/vllm.py +++ b/libs/langchain/langchain/llms/vllm.py @@ -1,10 +1,11 @@ from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.schema.output import Generation, LLMResult + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import BaseLLM from langchain.llms.openai import BaseOpenAI -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema.output import Generation, LLMResult from langchain.utils.openai import is_openai_v1 diff --git 
a/libs/langchain/langchain/llms/writer.py b/libs/langchain/langchain/llms/writer.py index 54a8a5fc215..45b56e389d2 100644 --- a/libs/langchain/langchain/llms/writer.py +++ b/libs/langchain/langchain/llms/writer.py @@ -1,11 +1,11 @@ from typing import Any, Dict, List, Mapping, Optional import requests +from langchain_core.pydantic_v1 import Extra, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/yandex.py b/libs/langchain/langchain/llms/yandex.py index 58a2c831685..c28c89a0834 100644 --- a/libs/langchain/langchain/llms/yandex.py +++ b/libs/langchain/langchain/llms/yandex.py @@ -1,13 +1,14 @@ from typing import Any, Dict, List, Mapping, Optional +from langchain_core.load.serializable import Serializable +from langchain_core.pydantic_v1 import root_validator + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.load.serializable import Serializable -from langchain.pydantic_v1 import root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/load/__init__.py b/libs/langchain/langchain/load/__init__.py index 1cceed481cf..2cac05bb527 100644 --- a/libs/langchain/langchain/load/__init__.py +++ b/libs/langchain/langchain/load/__init__.py @@ -1,6 +1,6 @@ """Serialization and deserialization.""" -from langchain.load.dump import dumpd, dumps -from langchain.load.load import load, loads +from langchain_core.load.dump import dumpd, dumps +from langchain_core.load.load import load, loads __all__ = [ "dumpd", diff --git a/libs/langchain/langchain/load/dump.py b/libs/langchain/langchain/load/dump.py index 151903712ad..fd9552018c8 100644 --- a/libs/langchain/langchain/load/dump.py +++ b/libs/langchain/langchain/load/dump.py @@ -1,26 +1,3 @@ -import json -from typing import Any, Dict +from langchain_core.load.dump import default, dumpd, dumps -from langchain.load.serializable import Serializable, to_json_not_implemented - - -def default(obj: Any) -> Any: - """Return a default value for a Serializable object or - a SerializedNotImplemented object.""" - if isinstance(obj, Serializable): - return obj.to_json() - else: - return to_json_not_implemented(obj) - - -def dumps(obj: Any, *, pretty: bool = False) -> str: - """Return a json string representation of an object.""" - if pretty: - return json.dumps(obj, default=default, indent=2) - else: - return json.dumps(obj, default=default) - - -def dumpd(obj: Any) -> Dict[str, Any]: - """Return a json dict representation of an object.""" - return json.loads(dumps(obj)) +__all__ = ["default", "dumps", "dumpd"] diff --git a/libs/langchain/langchain/load/load.py b/libs/langchain/langchain/load/load.py index 5d8b7ccd33e..5a32ea081f0 100644 --- a/libs/langchain/langchain/load/load.py +++ b/libs/langchain/langchain/load/load.py @@ -1,126 +1,3 @@ -import importlib -import json -import os -from typing import Any, Dict, List, Optional +from langchain_core.load.load import Reviver, load, loads -from langchain.load.serializable import Serializable - - -class Reviver: - """Reviver for JSON objects.""" - - def __init__( - self, - secrets_map: Optional[Dict[str, str]] = None, - valid_namespaces: Optional[List[str]] = None, - 
) -> None: - self.secrets_map = secrets_map or dict() - # By default only support langchain, but user can pass in additional namespaces - self.valid_namespaces = ( - ["langchain", *valid_namespaces] if valid_namespaces else ["langchain"] - ) - - def __call__(self, value: Dict[str, Any]) -> Any: - if ( - value.get("lc", None) == 1 - and value.get("type", None) == "secret" - and value.get("id", None) is not None - ): - [key] = value["id"] - if key in self.secrets_map: - return self.secrets_map[key] - else: - if key in os.environ and os.environ[key]: - return os.environ[key] - raise KeyError(f'Missing key "{key}" in load(secrets_map)') - - if ( - value.get("lc", None) == 1 - and value.get("type", None) == "not_implemented" - and value.get("id", None) is not None - ): - raise NotImplementedError( - "Trying to load an object that doesn't implement " - f"serialization: {value}" - ) - - if ( - value.get("lc", None) == 1 - and value.get("type", None) == "constructor" - and value.get("id", None) is not None - ): - [*namespace, name] = value["id"] - - if namespace[0] not in self.valid_namespaces: - raise ValueError(f"Invalid namespace: {value}") - - # The root namespace "langchain" is not a valid identifier. - if len(namespace) == 1 and namespace[0] == "langchain": - raise ValueError(f"Invalid namespace: {value}") - - mod = importlib.import_module(".".join(namespace)) - cls = getattr(mod, name) - - # The class must be a subclass of Serializable. - if not issubclass(cls, Serializable): - raise ValueError(f"Invalid namespace: {value}") - - # We don't need to recurse on kwargs - # as json.loads will do that for us. - kwargs = value.get("kwargs", dict()) - return cls(**kwargs) - - return value - - -def loads( - text: str, - *, - secrets_map: Optional[Dict[str, str]] = None, - valid_namespaces: Optional[List[str]] = None, -) -> Any: - """Revive a LangChain class from a JSON string. - Equivalent to `load(json.loads(text))`. - - Args: - text: The string to load. - secrets_map: A map of secrets to load. - valid_namespaces: A list of additional namespaces (modules) - to allow to be deserialized. - - Returns: - Revived LangChain objects. - """ - return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces)) - - -def load( - obj: Any, - *, - secrets_map: Optional[Dict[str, str]] = None, - valid_namespaces: Optional[List[str]] = None, -) -> Any: - """Revive a LangChain class from a JSON object. Use this if you already - have a parsed JSON object, eg. from `json.load` or `orjson.loads`. - - Args: - obj: The object to load. - secrets_map: A map of secrets to load. - valid_namespaces: A list of additional namespaces (modules) - to allow to be deserialized. - - Returns: - Revived LangChain objects. 
- """ - reviver = Reviver(secrets_map, valid_namespaces) - - def _load(obj: Any) -> Any: - if isinstance(obj, dict): - # Need to revive leaf nodes before reviving this node - loaded_obj = {k: _load(v) for k, v in obj.items()} - return reviver(loaded_obj) - if isinstance(obj, list): - return [_load(o) for o in obj] - return obj - - return _load(obj) +__all__ = ["Reviver", "loads", "load"] diff --git a/libs/langchain/langchain/load/serializable.py b/libs/langchain/langchain/load/serializable.py index 368f2f0d323..d1c0772d61d 100644 --- a/libs/langchain/langchain/load/serializable.py +++ b/libs/langchain/langchain/load/serializable.py @@ -1,207 +1,19 @@ -from abc import ABC -from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast +from langchain_core.load.serializable import ( + BaseSerialized, + Serializable, + SerializedConstructor, + SerializedNotImplemented, + SerializedSecret, + to_json_not_implemented, + try_neq_default, +) -from langchain.pydantic_v1 import BaseModel, PrivateAttr - - -class BaseSerialized(TypedDict): - """Base class for serialized objects.""" - - lc: int - id: List[str] - - -class SerializedConstructor(BaseSerialized): - """Serialized constructor.""" - - type: Literal["constructor"] - kwargs: Dict[str, Any] - - -class SerializedSecret(BaseSerialized): - """Serialized secret.""" - - type: Literal["secret"] - - -class SerializedNotImplemented(BaseSerialized): - """Serialized not implemented.""" - - type: Literal["not_implemented"] - repr: Optional[str] - - -def try_neq_default(value: Any, key: str, model: BaseModel) -> bool: - try: - return model.__fields__[key].get_default() != value - except Exception: - return True - - -class Serializable(BaseModel, ABC): - """Serializable base class.""" - - @classmethod - def is_lc_serializable(cls) -> bool: - """Is this class serializable?""" - return False - - @classmethod - def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object. - - For example, if the class is `langchain.llms.openai.OpenAI`, then the - namespace is ["langchain", "llms", "openai"] - """ - return cls.__module__.split(".") - - @property - def lc_secrets(self) -> Dict[str, str]: - """A map of constructor argument names to secret ids. - - For example, - {"openai_api_key": "OPENAI_API_KEY"} - """ - return dict() - - @property - def lc_attributes(self) -> Dict: - """List of attribute names that should be included in the serialized kwargs. - - These attributes must be accepted by the constructor. - """ - return {} - - @classmethod - def lc_id(cls) -> List[str]: - """A unique identifier for this class for serialization purposes. - - The unique identifier is a list of strings that describes the path - to the object. 
- """ - return [*cls.get_lc_namespace(), cls.__name__] - - class Config: - extra = "ignore" - - def __repr_args__(self) -> Any: - return [ - (k, v) - for k, v in super().__repr_args__() - if (k not in self.__fields__ or try_neq_default(v, k, self)) - ] - - _lc_kwargs = PrivateAttr(default_factory=dict) - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._lc_kwargs = kwargs - - def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]: - if not self.is_lc_serializable(): - return self.to_json_not_implemented() - - secrets = dict() - # Get latest values for kwargs if there is an attribute with same name - lc_kwargs = { - k: getattr(self, k, v) - for k, v in self._lc_kwargs.items() - if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore - } - - # Merge the lc_secrets and lc_attributes from every class in the MRO - for cls in [None, *self.__class__.mro()]: - # Once we get to Serializable, we're done - if cls is Serializable: - break - - if cls: - deprecated_attributes = [ - "lc_namespace", - "lc_serializable", - ] - - for attr in deprecated_attributes: - if hasattr(cls, attr): - raise ValueError( - f"Class {self.__class__} has a deprecated " - f"attribute {attr}. Please use the corresponding " - f"classmethod instead." - ) - - # Get a reference to self bound to each class in the MRO - this = cast(Serializable, self if cls is None else super(cls, self)) - - secrets.update(this.lc_secrets) - lc_kwargs.update(this.lc_attributes) - - # include all secrets, even if not specified in kwargs - # as these secrets may be passed as an environment variable instead - for key in secrets.keys(): - secret_value = getattr(self, key, None) or lc_kwargs.get(key) - if secret_value is not None: - lc_kwargs.update({key: secret_value}) - - return { - "lc": 1, - "type": "constructor", - "id": self.lc_id(), - "kwargs": lc_kwargs - if not secrets - else _replace_secrets(lc_kwargs, secrets), - } - - def to_json_not_implemented(self) -> SerializedNotImplemented: - return to_json_not_implemented(self) - - -def _replace_secrets( - root: Dict[Any, Any], secrets_map: Dict[str, str] -) -> Dict[Any, Any]: - result = root.copy() - for path, secret_id in secrets_map.items(): - [*parts, last] = path.split(".") - current = result - for part in parts: - if part not in current: - break - current[part] = current[part].copy() - current = current[part] - if last in current: - current[last] = { - "lc": 1, - "type": "secret", - "id": [secret_id], - } - return result - - -def to_json_not_implemented(obj: object) -> SerializedNotImplemented: - """Serialize a "not implemented" object. 
- - Args: - obj: object to serialize - - Returns: - SerializedNotImplemented - """ - _id: List[str] = [] - try: - if hasattr(obj, "__name__"): - _id = [*obj.__module__.split("."), obj.__name__] - elif hasattr(obj, "__class__"): - _id = [*obj.__class__.__module__.split("."), obj.__class__.__name__] - except Exception: - pass - - result: SerializedNotImplemented = { - "lc": 1, - "type": "not_implemented", - "id": _id, - "repr": None, - } - try: - result["repr"] = repr(obj) - except Exception: - pass - return result +__all__ = [ + "BaseSerialized", + "SerializedConstructor", + "SerializedSecret", + "SerializedNotImplemented", + "try_neq_default", + "Serializable", + "to_json_not_implemented", +] diff --git a/libs/langchain/langchain/memory/buffer.py b/libs/langchain/langchain/memory/buffer.py index 4afad7dffd3..f4c6d356d6c 100644 --- a/libs/langchain/langchain/memory/buffer.py +++ b/libs/langchain/langchain/memory/buffer.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema.messages import BaseMessage, get_buffer_string + from langchain.memory.chat_memory import BaseChatMemory, BaseMemory from langchain.memory.utils import get_prompt_input_key -from langchain.pydantic_v1 import root_validator -from langchain.schema.messages import BaseMessage, get_buffer_string class ConversationBufferMemory(BaseChatMemory): diff --git a/libs/langchain/langchain/memory/buffer_window.py b/libs/langchain/langchain/memory/buffer_window.py index 05b883e6d7d..50ddbad655a 100644 --- a/libs/langchain/langchain/memory/buffer_window.py +++ b/libs/langchain/langchain/memory/buffer_window.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Union +from langchain_core.schema.messages import BaseMessage, get_buffer_string + from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema.messages import BaseMessage, get_buffer_string class ConversationBufferWindowMemory(BaseChatMemory): diff --git a/libs/langchain/langchain/memory/chat_memory.py b/libs/langchain/langchain/memory/chat_memory.py index cb49c12d964..882fc1a4e27 100644 --- a/libs/langchain/langchain/memory/chat_memory.py +++ b/libs/langchain/langchain/memory/chat_memory.py @@ -1,10 +1,11 @@ from abc import ABC from typing import Any, Dict, Optional, Tuple +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BaseChatMessageHistory, BaseMemory + from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory from langchain.memory.utils import get_prompt_input_key -from langchain.pydantic_v1 import Field -from langchain.schema import BaseChatMessageHistory, BaseMemory class BaseChatMemory(BaseMemory, ABC): diff --git a/libs/langchain/langchain/memory/chat_message_histories/cassandra.py b/libs/langchain/langchain/memory/chat_message_histories/cassandra.py index d50be8d5a6c..73baa950714 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/cassandra.py +++ b/libs/langchain/langchain/memory/chat_message_histories/cassandra.py @@ -8,10 +8,14 @@ from typing import List if typing.TYPE_CHECKING: from cassandra.cluster import Session -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict +from langchain_core.schema.messages import ( + BaseMessage, + _message_to_dict, + messages_from_dict, +) DEFAULT_TABLE_NAME = "message_store" DEFAULT_TTL_SECONDS = None diff --git 
a/libs/langchain/langchain/memory/chat_message_histories/cosmos_db.py b/libs/langchain/langchain/memory/chat_message_histories/cosmos_db.py index ccf6546d037..e01a2343880 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/cosmos_db.py +++ b/libs/langchain/langchain/memory/chat_message_histories/cosmos_db.py @@ -5,10 +5,14 @@ import logging from types import TracebackType from typing import TYPE_CHECKING, Any, List, Optional, Type -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict +from langchain_core.schema.messages import ( + BaseMessage, + messages_from_dict, + messages_to_dict, +) logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py b/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py index 353356ef5ba..ff4d8960811 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py +++ b/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py @@ -3,10 +3,10 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING, Dict, List, Optional -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( BaseMessage, _message_to_dict, messages_from_dict, diff --git a/libs/langchain/langchain/memory/chat_message_histories/elasticsearch.py b/libs/langchain/langchain/memory/chat_message_histories/elasticsearch.py index e3b3397c1bf..46c49730537 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/elasticsearch.py +++ b/libs/langchain/langchain/memory/chat_message_histories/elasticsearch.py @@ -3,8 +3,12 @@ import logging from time import time from typing import TYPE_CHECKING, Any, Dict, List, Optional -from langchain.schema import BaseChatMessageHistory -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict +from langchain_core.schema import BaseChatMessageHistory +from langchain_core.schema.messages import ( + BaseMessage, + _message_to_dict, + messages_from_dict, +) if TYPE_CHECKING: from elasticsearch import Elasticsearch diff --git a/libs/langchain/langchain/memory/chat_message_histories/file.py b/libs/langchain/langchain/memory/chat_message_histories/file.py index 912ff740cb8..b9d2943cf09 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/file.py +++ b/libs/langchain/langchain/memory/chat_message_histories/file.py @@ -3,10 +3,14 @@ import logging from pathlib import Path from typing import List -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict +from langchain_core.schema.messages import ( + BaseMessage, + messages_from_dict, + messages_to_dict, +) logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/memory/chat_message_histories/firestore.py b/libs/langchain/langchain/memory/chat_message_histories/firestore.py index e1f2435b27f..d8aae0becc0 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/firestore.py +++ b/libs/langchain/langchain/memory/chat_message_histories/firestore.py @@ -4,10 +4,14 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING, List, Optional -from langchain.schema import ( +from langchain_core.schema import ( 
BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict +from langchain_core.schema.messages import ( + BaseMessage, + messages_from_dict, + messages_to_dict, +) logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/memory/chat_message_histories/in_memory.py b/libs/langchain/langchain/memory/chat_message_histories/in_memory.py index 53fbbb201ea..3dc5142e461 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/in_memory.py +++ b/libs/langchain/langchain/memory/chat_message_histories/in_memory.py @@ -1,10 +1,10 @@ from typing import List -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema import ( +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage +from langchain_core.schema.messages import BaseMessage class ChatMessageHistory(BaseChatMessageHistory, BaseModel): diff --git a/libs/langchain/langchain/memory/chat_message_histories/momento.py b/libs/langchain/langchain/memory/chat_message_histories/momento.py index 12ffbaf4c2b..c2d70b88b6a 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/momento.py +++ b/libs/langchain/langchain/memory/chat_message_histories/momento.py @@ -4,10 +4,15 @@ import json from datetime import timedelta from typing import TYPE_CHECKING, Any, Optional -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict +from langchain_core.schema.messages import ( + BaseMessage, + _message_to_dict, + messages_from_dict, +) + from langchain.utils import get_from_env if TYPE_CHECKING: diff --git a/libs/langchain/langchain/memory/chat_message_histories/mongodb.py b/libs/langchain/langchain/memory/chat_message_histories/mongodb.py index 5cc3af8dbef..20c8bfb5cf0 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/mongodb.py +++ b/libs/langchain/langchain/memory/chat_message_histories/mongodb.py @@ -2,10 +2,14 @@ import json import logging from typing import List -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict +from langchain_core.schema.messages import ( + BaseMessage, + _message_to_dict, + messages_from_dict, +) logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/memory/chat_message_histories/neo4j.py b/libs/langchain/langchain/memory/chat_message_histories/neo4j.py index dfbf75cc304..b198ab05b24 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/neo4j.py +++ b/libs/langchain/langchain/memory/chat_message_histories/neo4j.py @@ -1,7 +1,8 @@ from typing import List, Optional, Union -from langchain.schema import BaseChatMessageHistory -from langchain.schema.messages import BaseMessage, messages_from_dict +from langchain_core.schema import BaseChatMessageHistory +from langchain_core.schema.messages import BaseMessage, messages_from_dict + from langchain.utils import get_from_env diff --git a/libs/langchain/langchain/memory/chat_message_histories/postgres.py b/libs/langchain/langchain/memory/chat_message_histories/postgres.py index fb1d3b34886..19857662293 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/postgres.py +++ b/libs/langchain/langchain/memory/chat_message_histories/postgres.py @@ -2,10 +2,14 @@ 
import json import logging from typing import List -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict +from langchain_core.schema.messages import ( + BaseMessage, + _message_to_dict, + messages_from_dict, +) logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/memory/chat_message_histories/redis.py b/libs/langchain/langchain/memory/chat_message_histories/redis.py index 13904a463aa..6939d2d7af1 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/redis.py +++ b/libs/langchain/langchain/memory/chat_message_histories/redis.py @@ -2,10 +2,15 @@ import json import logging from typing import List, Optional -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict +from langchain_core.schema.messages import ( + BaseMessage, + _message_to_dict, + messages_from_dict, +) + from langchain.utilities.redis import get_client logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py b/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py index 2e690736917..c995cf7338b 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py +++ b/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py @@ -3,8 +3,12 @@ from time import sleep from typing import Any, Callable, List, Union from uuid import uuid4 -from langchain.schema import BaseChatMessageHistory -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict +from langchain_core.schema import BaseChatMessageHistory +from langchain_core.schema.messages import ( + BaseMessage, + _message_to_dict, + messages_from_dict, +) class RocksetChatMessageHistory(BaseChatMessageHistory): diff --git a/libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py b/libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py index 74489abdd59..7d1c9274583 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py +++ b/libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py @@ -6,10 +6,14 @@ from typing import ( List, ) -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict +from langchain_core.schema.messages import ( + BaseMessage, + _message_to_dict, + messages_from_dict, +) logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/memory/chat_message_histories/sql.py b/libs/langchain/langchain/memory/chat_message_histories/sql.py index 610d049c51e..e83ca7a971e 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/sql.py +++ b/libs/langchain/langchain/memory/chat_message_histories/sql.py @@ -9,12 +9,15 @@ try: from sqlalchemy.orm import declarative_base except ImportError: from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker - -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict +from langchain_core.schema.messages import ( + BaseMessage, + _message_to_dict, + messages_from_dict, +) +from sqlalchemy.orm import sessionmaker logger = 
logging.getLogger(__name__) diff --git a/libs/langchain/langchain/memory/chat_message_histories/streamlit.py b/libs/langchain/langchain/memory/chat_message_histories/streamlit.py index 34280356f74..111b86fc38a 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/streamlit.py +++ b/libs/langchain/langchain/memory/chat_message_histories/streamlit.py @@ -1,9 +1,9 @@ from typing import List -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage +from langchain_core.schema.messages import BaseMessage class StreamlitChatMessageHistory(BaseChatMessageHistory): diff --git a/libs/langchain/langchain/memory/chat_message_histories/upstash_redis.py b/libs/langchain/langchain/memory/chat_message_histories/upstash_redis.py index 8cc3ac00f40..94c83d1f684 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/upstash_redis.py +++ b/libs/langchain/langchain/memory/chat_message_histories/upstash_redis.py @@ -2,10 +2,14 @@ import json import logging from typing import List, Optional -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict +from langchain_core.schema.messages import ( + BaseMessage, + _message_to_dict, + messages_from_dict, +) logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/memory/chat_message_histories/xata.py b/libs/langchain/langchain/memory/chat_message_histories/xata.py index a51da97a46f..e9e98af5256 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/xata.py +++ b/libs/langchain/langchain/memory/chat_message_histories/xata.py @@ -1,10 +1,14 @@ import json from typing import List -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict +from langchain_core.schema.messages import ( + BaseMessage, + _message_to_dict, + messages_from_dict, +) class XataChatMessageHistory(BaseChatMessageHistory): diff --git a/libs/langchain/langchain/memory/chat_message_histories/zep.py b/libs/langchain/langchain/memory/chat_message_histories/zep.py index 853e6bfe1c5..a9709d2218c 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/zep.py +++ b/libs/langchain/langchain/memory/chat_message_histories/zep.py @@ -3,10 +3,10 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional -from langchain.schema import ( +from langchain_core.schema import ( BaseChatMessageHistory, ) -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( AIMessage, BaseMessage, HumanMessage, diff --git a/libs/langchain/langchain/memory/combined.py b/libs/langchain/langchain/memory/combined.py index 8cf76046237..f67064e7275 100644 --- a/libs/langchain/langchain/memory/combined.py +++ b/libs/langchain/langchain/memory/combined.py @@ -1,9 +1,10 @@ import warnings from typing import Any, Dict, List, Set +from langchain_core.pydantic_v1 import validator +from langchain_core.schema import BaseMemory + from langchain.memory.chat_memory import BaseChatMemory -from langchain.pydantic_v1 import validator -from langchain.schema import BaseMemory class CombinedMemory(BaseMemory): diff --git a/libs/langchain/langchain/memory/entity.py b/libs/langchain/langchain/memory/entity.py index a4ae4c132cd..fe7576573e6 100644 --- 
a/libs/langchain/langchain/memory/entity.py +++ b/libs/langchain/langchain/memory/entity.py @@ -3,6 +3,11 @@ from abc import ABC, abstractmethod from itertools import islice from typing import Any, Dict, Iterable, List, Optional +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import BaseMessage, get_buffer_string + from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.prompt import ( @@ -10,10 +15,6 @@ from langchain.memory.prompt import ( ENTITY_SUMMARIZATION_PROMPT, ) from langchain.memory.utils import get_prompt_input_key -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import BaseMessage, get_buffer_string from langchain.utilities.redis import get_client logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/memory/kg.py b/libs/langchain/langchain/memory/kg.py index 4f79bd76a57..831f649b606 100644 --- a/libs/langchain/langchain/memory/kg.py +++ b/libs/langchain/langchain/memory/kg.py @@ -1,5 +1,10 @@ from typing import Any, Dict, List, Type, Union +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BasePromptTemplate +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import BaseMessage, SystemMessage, get_buffer_string + from langchain.chains.llm import LLMChain from langchain.graphs import NetworkxEntityGraph from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples @@ -9,10 +14,6 @@ from langchain.memory.prompt import ( KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT, ) from langchain.memory.utils import get_prompt_input_key -from langchain.pydantic_v1 import Field -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string class ConversationKGMemory(BaseChatMemory): diff --git a/libs/langchain/langchain/memory/motorhead_memory.py b/libs/langchain/langchain/memory/motorhead_memory.py index da9d9d82a3b..a5607c1ba05 100644 --- a/libs/langchain/langchain/memory/motorhead_memory.py +++ b/libs/langchain/langchain/memory/motorhead_memory.py @@ -1,9 +1,9 @@ from typing import Any, Dict, List, Optional import requests +from langchain_core.schema.messages import get_buffer_string from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema.messages import get_buffer_string MANAGED_URL = "https://api.getmetal.io/v1/motorhead" # LOCAL_URL = "http://localhost:8080" diff --git a/libs/langchain/langchain/memory/prompt.py b/libs/langchain/langchain/memory/prompt.py index af74b6554df..c16e8e24931 100644 --- a/libs/langchain/langchain/memory/prompt.py +++ b/libs/langchain/langchain/memory/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate _DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE = """You are an assistant to a human, powered by a large language model trained by OpenAI. 
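
The hunks above and below all follow the same mechanical pattern: the concrete implementations move into `langchain_core`, while the old `langchain.*` modules are reduced to thin re-export shims so existing import paths keep resolving. The sketch below is a minimal illustration of that compatibility, not part of the patch itself; it assumes only that `langchain` and `langchain_core` are installed together (after this change `langchain` depends on `langchain_core`), and it uses the list output parser and serialization helpers whose re-exports appear elsewhere in this diff.

```python
# Minimal compatibility sketch: assumes `langchain` and `langchain_core` are both
# importable (after this change `langchain` depends on `langchain_core`).
from langchain.load.dump import dumpd as legacy_dumpd
from langchain.output_parsers.list import CommaSeparatedListOutputParser as LegacyParser
from langchain_core.load.dump import dumpd
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser

# The legacy modules are now re-export shims, so old and new import paths
# resolve to the very same objects.
assert LegacyParser is CommaSeparatedListOutputParser
assert legacy_dumpd is dumpd

# Behaviour is unchanged: the parser still splits on ", " and still serializes
# to an "lc" constructor dict, exactly as in the implementation moved to core.
parser = CommaSeparatedListOutputParser()
print(parser.parse("foo, bar, baz"))   # ['foo', 'bar', 'baz']
print(dumpd(parser)["type"])           # 'constructor'
```

The same identity check should hold for the other modules in this patch that were reduced to `__all__` re-export blocks, such as `langchain.load.load` and `langchain.load.serializable`.
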
diff --git a/libs/langchain/langchain/memory/readonly.py b/libs/langchain/langchain/memory/readonly.py index 78a6769b0a3..c037a90b51a 100644 --- a/libs/langchain/langchain/memory/readonly.py +++ b/libs/langchain/langchain/memory/readonly.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List -from langchain.schema import BaseMemory +from langchain_core.schema import BaseMemory class ReadOnlySharedMemory(BaseMemory): diff --git a/libs/langchain/langchain/memory/simple.py b/libs/langchain/langchain/memory/simple.py index 00caa9d5b29..03fb4416743 100644 --- a/libs/langchain/langchain/memory/simple.py +++ b/libs/langchain/langchain/memory/simple.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List -from langchain.schema import BaseMemory +from langchain_core.schema import BaseMemory class SimpleMemory(BaseMemory): diff --git a/libs/langchain/langchain/memory/summary.py b/libs/langchain/langchain/memory/summary.py index a6cbfa727c3..1f78ac24daf 100644 --- a/libs/langchain/langchain/memory/summary.py +++ b/libs/langchain/langchain/memory/summary.py @@ -2,16 +2,17 @@ from __future__ import annotations from typing import Any, Dict, List, Type -from langchain.chains.llm import LLMChain -from langchain.memory.chat_memory import BaseChatMemory -from langchain.memory.prompt import SUMMARY_PROMPT -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema import ( +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema import ( BaseChatMessageHistory, BasePromptTemplate, ) -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import BaseMessage, SystemMessage, get_buffer_string + +from langchain.chains.llm import LLMChain +from langchain.memory.chat_memory import BaseChatMemory +from langchain.memory.prompt import SUMMARY_PROMPT class SummarizerMixin(BaseModel): diff --git a/libs/langchain/langchain/memory/summary_buffer.py b/libs/langchain/langchain/memory/summary_buffer.py index f0a31b41244..d9c7bf29d70 100644 --- a/libs/langchain/langchain/memory/summary_buffer.py +++ b/libs/langchain/langchain/memory/summary_buffer.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema.messages import BaseMessage, get_buffer_string + from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.summary import SummarizerMixin -from langchain.pydantic_v1 import root_validator -from langchain.schema.messages import BaseMessage, get_buffer_string class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin): diff --git a/libs/langchain/langchain/memory/token_buffer.py b/libs/langchain/langchain/memory/token_buffer.py index 8c9c37460ff..57cf6fb5dfc 100644 --- a/libs/langchain/langchain/memory/token_buffer.py +++ b/libs/langchain/langchain/memory/token_buffer.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import BaseMessage, get_buffer_string + from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import BaseMessage, get_buffer_string class ConversationTokenBufferMemory(BaseChatMemory): diff --git a/libs/langchain/langchain/memory/vectorstore.py 
b/libs/langchain/langchain/memory/vectorstore.py index a35e74b6b33..d76be477b1c 100644 --- a/libs/langchain/langchain/memory/vectorstore.py +++ b/libs/langchain/langchain/memory/vectorstore.py @@ -2,11 +2,12 @@ from typing import Any, Dict, List, Optional, Sequence, Union +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import Document +from langchain_core.schema.vectorstore import VectorStoreRetriever + from langchain.memory.chat_memory import BaseMemory from langchain.memory.utils import get_prompt_input_key -from langchain.pydantic_v1 import Field -from langchain.schema import Document -from langchain.schema.vectorstore import VectorStoreRetriever class VectorStoreRetrieverMemory(BaseMemory): diff --git a/libs/langchain/langchain/model_laboratory.py b/libs/langchain/langchain/model_laboratory.py index 87c44a21143..2bbe5cb7b67 100644 --- a/libs/langchain/langchain/model_laboratory.py +++ b/libs/langchain/langchain/model_laboratory.py @@ -3,11 +3,12 @@ from __future__ import annotations from typing import List, Optional, Sequence +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.utils.input import get_color_mapping, print_text + from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.llms.base import BaseLLM -from langchain.prompts.prompt import PromptTemplate -from langchain.utils.input import get_color_mapping, print_text class ModelLaboratory: diff --git a/libs/langchain/langchain/output_parsers/boolean.py b/libs/langchain/langchain/output_parsers/boolean.py index f0990b8e050..5d704ba94a4 100644 --- a/libs/langchain/langchain/output_parsers/boolean.py +++ b/libs/langchain/langchain/output_parsers/boolean.py @@ -1,4 +1,4 @@ -from langchain.schema import BaseOutputParser +from langchain_core.schema import BaseOutputParser class BooleanOutputParser(BaseOutputParser[bool]): diff --git a/libs/langchain/langchain/output_parsers/combining.py b/libs/langchain/langchain/output_parsers/combining.py index 2ebbb8a5ff5..300eec1c175 100644 --- a/libs/langchain/langchain/output_parsers/combining.py +++ b/libs/langchain/langchain/output_parsers/combining.py @@ -2,8 +2,8 @@ from __future__ import annotations from typing import Any, Dict, List -from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseOutputParser +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import BaseOutputParser class CombiningOutputParser(BaseOutputParser): diff --git a/libs/langchain/langchain/output_parsers/datetime.py b/libs/langchain/langchain/output_parsers/datetime.py index bec68bb1cb1..5113bb07efc 100644 --- a/libs/langchain/langchain/output_parsers/datetime.py +++ b/libs/langchain/langchain/output_parsers/datetime.py @@ -2,7 +2,8 @@ import random from datetime import datetime, timedelta from typing import List -from langchain.schema import BaseOutputParser, OutputParserException +from langchain_core.schema import BaseOutputParser, OutputParserException + from langchain.utils import comma_list diff --git a/libs/langchain/langchain/output_parsers/enum.py b/libs/langchain/langchain/output_parsers/enum.py index b066e5ca144..a6baadbfd36 100644 --- a/libs/langchain/langchain/output_parsers/enum.py +++ b/libs/langchain/langchain/output_parsers/enum.py @@ -1,8 +1,8 @@ from enum import Enum from typing import Any, Dict, List, Type -from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseOutputParser, OutputParserException +from 
langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import BaseOutputParser, OutputParserException class EnumOutputParser(BaseOutputParser): diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py index b258ed3344a..37c186cc9ab 100644 --- a/libs/langchain/langchain/output_parsers/fix.py +++ b/libs/langchain/langchain/output_parsers/fix.py @@ -2,9 +2,14 @@ from __future__ import annotations from typing import Any, TypeVar +from langchain_core.schema import ( + BaseOutputParser, + BasePromptTemplate, + OutputParserException, +) +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT -from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException -from langchain.schema.language_model import BaseLanguageModel T = TypeVar("T") diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index 53c2b1c3c98..e4282c5975b 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -6,8 +6,7 @@ from json import JSONDecodeError from typing import Any, Callable, List, Optional import jsonpatch - -from langchain.schema.output_parser import ( +from langchain_core.schema.output_parser import ( BaseCumulativeTransformOutputParser, OutputParserException, ) diff --git a/libs/langchain/langchain/output_parsers/list.py b/libs/langchain/langchain/output_parsers/list.py index a1b955ef0a4..b5ffd8a3ab7 100644 --- a/libs/langchain/langchain/output_parsers/list.py +++ b/libs/langchain/langchain/output_parsers/list.py @@ -1,79 +1,13 @@ -from __future__ import annotations +from langchain_core.output_parsers.list import ( + CommaSeparatedListOutputParser, + ListOutputParser, + MarkdownListOutputParser, + NumberedListOutputParser, +) -import re -from abc import abstractmethod -from typing import List - -from langchain.schema import BaseOutputParser - - -class ListOutputParser(BaseOutputParser[List[str]]): - """Parse the output of an LLM call to a list.""" - - @property - def _type(self) -> str: - return "list" - - @abstractmethod - def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" - - -class CommaSeparatedListOutputParser(ListOutputParser): - """Parse the output of an LLM call to a comma-separated list.""" - - @classmethod - def is_lc_serializable(cls) -> bool: - return True - - def get_format_instructions(self) -> str: - return ( - "Your response should be a list of comma separated values, " - "eg: `foo, bar, baz`" - ) - - def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" - return text.strip().split(", ") - - @property - def _type(self) -> str: - return "comma-separated-list" - - -class NumberedListOutputParser(ListOutputParser): - """Parse a numbered list.""" - - def get_format_instructions(self) -> str: - return ( - "Your response should be a numbered list with each item on a new line. " - "For example: \n\n1. foo\n\n2. bar\n\n3. 
baz" - ) - - def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" - pattern = r"\d+\.\s([^\n]+)" - - # Extract the text of each item - matches = re.findall(pattern, text) - return matches - - @property - def _type(self) -> str: - return "numbered-list" - - -class MarkdownListOutputParser(ListOutputParser): - """Parse a markdown list.""" - - def get_format_instructions(self) -> str: - return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`" - - def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" - pattern = r"-\s([^\n]+)" - return re.findall(pattern, text) - - @property - def _type(self) -> str: - return "markdown-list" +__all__ = [ + "ListOutputParser", + "CommaSeparatedListOutputParser", + "NumberedListOutputParser", + "MarkdownListOutputParser", +] diff --git a/libs/langchain/langchain/output_parsers/openai_functions.py b/libs/langchain/langchain/output_parsers/openai_functions.py index 78b3a989b4c..8a5259ac0b8 100644 --- a/libs/langchain/langchain/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/output_parsers/openai_functions.py @@ -3,19 +3,19 @@ import json from typing import Any, Dict, List, Optional, Type, Union import jsonpatch - -from langchain.output_parsers.json import parse_partial_json -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema import ( +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema import ( ChatGeneration, Generation, OutputParserException, ) -from langchain.schema.output_parser import ( +from langchain_core.schema.output_parser import ( BaseCumulativeTransformOutputParser, BaseGenerationOutputParser, ) +from langchain.output_parsers.json import parse_partial_json + class OutputFunctionsParser(BaseGenerationOutputParser[Any]): """Parse an output that is one of sets of values.""" diff --git a/libs/langchain/langchain/output_parsers/openai_tools.py b/libs/langchain/langchain/output_parsers/openai_tools.py index 2df18154318..52547182639 100644 --- a/libs/langchain/langchain/output_parsers/openai_tools.py +++ b/libs/langchain/langchain/output_parsers/openai_tools.py @@ -2,13 +2,13 @@ import copy import json from typing import Any, List, Type -from langchain.pydantic_v1 import BaseModel -from langchain.schema import ( +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema import ( ChatGeneration, Generation, OutputParserException, ) -from langchain.schema.output_parser import ( +from langchain_core.schema.output_parser import ( BaseGenerationOutputParser, ) diff --git a/libs/langchain/langchain/output_parsers/prompts.py b/libs/langchain/langchain/output_parsers/prompts.py index 5ea37b24a26..dd06a70c58b 100644 --- a/libs/langchain/langchain/output_parsers/prompts.py +++ b/libs/langchain/langchain/output_parsers/prompts.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain_core.prompts.prompt import PromptTemplate NAIVE_FIX = """Instructions: -------------- diff --git a/libs/langchain/langchain/output_parsers/pydantic.py b/libs/langchain/langchain/output_parsers/pydantic.py index f60408eccff..80fb5e926c3 100644 --- a/libs/langchain/langchain/output_parsers/pydantic.py +++ b/libs/langchain/langchain/output_parsers/pydantic.py @@ -2,9 +2,10 @@ import json import re from typing import Type, TypeVar +from langchain_core.pydantic_v1 import BaseModel, ValidationError +from langchain_core.schema import BaseOutputParser, 
OutputParserException + from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS -from langchain.pydantic_v1 import BaseModel, ValidationError -from langchain.schema import BaseOutputParser, OutputParserException T = TypeVar("T", bound=BaseModel) diff --git a/libs/langchain/langchain/output_parsers/rail_parser.py b/libs/langchain/langchain/output_parsers/rail_parser.py index 093b457b91b..64077cf0966 100644 --- a/libs/langchain/langchain/output_parsers/rail_parser.py +++ b/libs/langchain/langchain/output_parsers/rail_parser.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Any, Callable, Dict, Optional -from langchain.schema import BaseOutputParser +from langchain_core.schema import BaseOutputParser class GuardrailsOutputParser(BaseOutputParser): diff --git a/libs/langchain/langchain/output_parsers/regex.py b/libs/langchain/langchain/output_parsers/regex.py index 5166c09c390..cc66b95a51f 100644 --- a/libs/langchain/langchain/output_parsers/regex.py +++ b/libs/langchain/langchain/output_parsers/regex.py @@ -3,7 +3,7 @@ from __future__ import annotations import re from typing import Dict, List, Optional -from langchain.schema import BaseOutputParser +from langchain_core.schema import BaseOutputParser class RegexParser(BaseOutputParser): diff --git a/libs/langchain/langchain/output_parsers/regex_dict.py b/libs/langchain/langchain/output_parsers/regex_dict.py index a52b8980f3a..9cb31db3665 100644 --- a/libs/langchain/langchain/output_parsers/regex_dict.py +++ b/libs/langchain/langchain/output_parsers/regex_dict.py @@ -3,7 +3,7 @@ from __future__ import annotations import re from typing import Dict, Optional -from langchain.schema import BaseOutputParser +from langchain_core.schema import BaseOutputParser class RegexDictParser(BaseOutputParser): diff --git a/libs/langchain/langchain/output_parsers/retry.py b/libs/langchain/langchain/output_parsers/retry.py index c78f2469b99..e5ec6d29a7a 100644 --- a/libs/langchain/langchain/output_parsers/retry.py +++ b/libs/langchain/langchain/output_parsers/retry.py @@ -2,14 +2,14 @@ from __future__ import annotations from typing import Any, TypeVar -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import ( +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.schema import ( BaseOutputParser, BasePromptTemplate, OutputParserException, PromptValue, ) -from langchain.schema.language_model import BaseLanguageModel +from langchain_core.schema.language_model import BaseLanguageModel NAIVE_COMPLETION_RETRY = """Prompt: {prompt} diff --git a/libs/langchain/langchain/output_parsers/structured.py b/libs/langchain/langchain/output_parsers/structured.py index 75bc1103dfc..24f6177a14d 100644 --- a/libs/langchain/langchain/output_parsers/structured.py +++ b/libs/langchain/langchain/output_parsers/structured.py @@ -2,13 +2,14 @@ from __future__ import annotations from typing import Any, List +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema import BaseOutputParser + from langchain.output_parsers.format_instructions import ( STRUCTURED_FORMAT_INSTRUCTIONS, STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS, ) from langchain.output_parsers.json import parse_and_check_json_markdown -from langchain.pydantic_v1 import BaseModel -from langchain.schema import BaseOutputParser line_template = '\t"{name}": {type} // {description}' diff --git a/libs/langchain/langchain/output_parsers/xml.py b/libs/langchain/langchain/output_parsers/xml.py index 
8ffe4167b66..94361f30f9b 100644 --- a/libs/langchain/langchain/output_parsers/xml.py +++ b/libs/langchain/langchain/output_parsers/xml.py @@ -2,8 +2,9 @@ import re import xml.etree.ElementTree as ET from typing import Any, Dict, List, Optional +from langchain_core.schema import BaseOutputParser + from langchain.output_parsers.format_instructions import XML_FORMAT_INSTRUCTIONS -from langchain.schema import BaseOutputParser class XMLOutputParser(BaseOutputParser): diff --git a/libs/langchain/langchain/prompts/__init__.py b/libs/langchain/langchain/prompts/__init__.py index 66c2dbe1223..18484ac266e 100644 --- a/libs/langchain/langchain/prompts/__init__.py +++ b/libs/langchain/langchain/prompts/__init__.py @@ -27,8 +27,8 @@ from multiple components. Prompt classes and functions make constructing ChatPromptValue """ # noqa: E501 -from langchain.prompts.base import StringPromptTemplate -from langchain.prompts.chat import ( +from langchain_core.prompts.base import StringPromptTemplate +from langchain_core.prompts.chat import ( AIMessagePromptTemplate, BaseChatPromptTemplate, ChatMessagePromptTemplate, @@ -37,21 +37,22 @@ from langchain.prompts.chat import ( MessagesPlaceholder, SystemMessagePromptTemplate, ) +from langchain_core.prompts.few_shot import ( + FewShotChatMessagePromptTemplate, + FewShotPromptTemplate, +) +from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates +from langchain_core.prompts.loading import load_prompt +from langchain_core.prompts.pipeline import PipelinePromptTemplate +from langchain_core.prompts.prompt import Prompt, PromptTemplate +from langchain_core.schema.prompt_template import BasePromptTemplate + from langchain.prompts.example_selector import ( LengthBasedExampleSelector, MaxMarginalRelevanceExampleSelector, NGramOverlapExampleSelector, SemanticSimilarityExampleSelector, ) -from langchain.prompts.few_shot import ( - FewShotChatMessagePromptTemplate, - FewShotPromptTemplate, -) -from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates -from langchain.prompts.loading import load_prompt -from langchain.prompts.pipeline import PipelinePromptTemplate -from langchain.prompts.prompt import Prompt, PromptTemplate -from langchain.schema.prompt_template import BasePromptTemplate __all__ = [ "AIMessagePromptTemplate", diff --git a/libs/langchain/langchain/prompts/base.py b/libs/langchain/langchain/prompts/base.py index f2744c43a14..a266d5e4c0b 100644 --- a/libs/langchain/langchain/prompts/base.py +++ b/libs/langchain/langchain/prompts/base.py @@ -1,173 +1,19 @@ -"""BasePrompt schema definition.""" -from __future__ import annotations +from langchain_core.prompts.base import ( + StringPromptTemplate, + StringPromptValue, + check_valid_template, + get_template_variables, + jinja2_formatter, + validate_jinja2, +) +from langchain_core.schema.prompt_template import BasePromptTemplate -import warnings -from abc import ABC -from string import Formatter -from typing import Any, Callable, Dict, List, Literal, Set - -from langchain.schema.messages import BaseMessage, HumanMessage -from langchain.schema.prompt import PromptValue -from langchain.schema.prompt_template import BasePromptTemplate -from langchain.utils.formatting import formatter - - -def jinja2_formatter(template: str, **kwargs: Any) -> str: - """Format a template using jinja2. - - *Security warning*: As of LangChain 0.0.329, this method uses Jinja2's - SandboxedEnvironment by default. 
However, this sand-boxing should - be treated as a best-effort approach rather than a guarantee of security. - Do not accept jinja2 templates from untrusted sources as they may lead - to arbitrary Python code execution. - - https://jinja.palletsprojects.com/en/3.1.x/sandbox/ - """ - try: - from jinja2.sandbox import SandboxedEnvironment - except ImportError: - raise ImportError( - "jinja2 not installed, which is needed to use the jinja2_formatter. " - "Please install it with `pip install jinja2`." - "Please be cautious when using jinja2 templates. " - "Do not expand jinja2 templates using unverified or user-controlled " - "inputs as that can result in arbitrary Python code execution." - ) - - # This uses a sandboxed environment to prevent arbitrary code execution. - # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing. - # Please treat this sand-boxing as a best-effort approach rather than - # a guarantee of security. - # We recommend to never use jinja2 templates with untrusted inputs. - # https://jinja.palletsprojects.com/en/3.1.x/sandbox/ - # approach not a guarantee of security. - return SandboxedEnvironment().from_string(template).render(**kwargs) - - -def validate_jinja2(template: str, input_variables: List[str]) -> None: - """ - Validate that the input variables are valid for the template. - Issues a warning if missing or extra variables are found. - - Args: - template: The template string. - input_variables: The input variables. - """ - input_variables_set = set(input_variables) - valid_variables = _get_jinja2_variables_from_template(template) - missing_variables = valid_variables - input_variables_set - extra_variables = input_variables_set - valid_variables - - warning_message = "" - if missing_variables: - warning_message += f"Missing variables: {missing_variables} " - - if extra_variables: - warning_message += f"Extra variables: {extra_variables}" - - if warning_message: - warnings.warn(warning_message.strip()) - - -def _get_jinja2_variables_from_template(template: str) -> Set[str]: - try: - from jinja2 import Environment, meta - except ImportError: - raise ImportError( - "jinja2 not installed, which is needed to use the jinja2_formatter. " - "Please install it with `pip install jinja2`." - ) - env = Environment() - ast = env.parse(template) - variables = meta.find_undeclared_variables(ast) - return variables - - -DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { - "f-string": formatter.format, - "jinja2": jinja2_formatter, -} - -DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = { - "f-string": formatter.validate_input_variables, - "jinja2": validate_jinja2, -} - - -def check_valid_template( - template: str, template_format: str, input_variables: List[str] -) -> None: - """Check that template string is valid. - - Args: - template: The template string. - template_format: The template format. Should be one of "f-string" or "jinja2". - input_variables: The input variables. - - Raises: - ValueError: If the template format is not supported. - """ - if template_format not in DEFAULT_FORMATTER_MAPPING: - valid_formats = list(DEFAULT_FORMATTER_MAPPING) - raise ValueError( - f"Invalid template format. Got `{template_format}`;" - f" should be one of {valid_formats}" - ) - try: - validator_func = DEFAULT_VALIDATOR_MAPPING[template_format] - validator_func(template, input_variables) - except KeyError as e: - raise ValueError( - "Invalid prompt schema; check for mismatched or missing input parameters. 
" - + str(e) - ) - - -def get_template_variables(template: str, template_format: str) -> List[str]: - """Get the variables from the template. - - Args: - template: The template string. - template_format: The template format. Should be one of "f-string" or "jinja2". - - Returns: - The variables from the template. - - Raises: - ValueError: If the template format is not supported. - """ - if template_format == "jinja2": - # Get the variables for the template - input_variables = _get_jinja2_variables_from_template(template) - elif template_format == "f-string": - input_variables = { - v for _, v, _, _ in Formatter().parse(template) if v is not None - } - else: - raise ValueError(f"Unsupported template format: {template_format}") - - return sorted(input_variables) - - -class StringPromptValue(PromptValue): - """String prompt value.""" - - text: str - """Prompt text.""" - type: Literal["StringPromptValue"] = "StringPromptValue" - - def to_string(self) -> str: - """Return prompt as string.""" - return self.text - - def to_messages(self) -> List[BaseMessage]: - """Return prompt as messages.""" - return [HumanMessage(content=self.text)] - - -class StringPromptTemplate(BasePromptTemplate, ABC): - """String prompt that exposes the format method, returning a prompt.""" - - def format_prompt(self, **kwargs: Any) -> PromptValue: - """Create Chat Messages.""" - return StringPromptValue(text=self.format(**kwargs)) +__all__ = [ + "jinja2_formatter", + "validate_jinja2", + "check_valid_template", + "get_template_variables", + "StringPromptValue", + "StringPromptTemplate", + "BasePromptTemplate", +] diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py index e8bdd771400..fc965d11204 100644 --- a/libs/langchain/langchain/prompts/chat.py +++ b/libs/langchain/langchain/prompts/chat.py @@ -1,748 +1,27 @@ -"""Chat prompt template.""" -from __future__ import annotations - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Sequence, - Set, - Tuple, - Type, - TypeVar, - Union, - overload, +from langchain_core.prompts.chat import ( + AIMessagePromptTemplate, + BaseChatPromptTemplate, + BaseMessagePromptTemplate, + BaseStringMessagePromptTemplate, + ChatMessagePromptTemplate, + ChatPromptTemplate, + ChatPromptValue, + ChatPromptValueConcrete, + HumanMessagePromptTemplate, + MessagesPlaceholder, + SystemMessagePromptTemplate, ) -from langchain._api import deprecated -from langchain.load.serializable import Serializable -from langchain.prompts.base import StringPromptTemplate -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import ( - BasePromptTemplate, - PromptValue, -) -from langchain.schema.messages import ( - AIMessage, - AnyMessage, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, - get_buffer_string, -) - - -class BaseMessagePromptTemplate(Serializable, ABC): - """Base class for message prompt templates.""" - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether or not the class is serializable.""" - return True - - @abstractmethod - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - """Format messages from kwargs. Should return a list of BaseMessages. - - Args: - **kwargs: Keyword arguments to use for formatting. - - Returns: - List of BaseMessages. 
- """ - - @property - @abstractmethod - def input_variables(self) -> List[str]: - """Input variables for this prompt template. - - Returns: - List of input variables. - """ - - def __add__(self, other: Any) -> ChatPromptTemplate: - """Combine two prompt templates. - - Args: - other: Another prompt template. - - Returns: - Combined prompt template. - """ - prompt = ChatPromptTemplate(messages=[self]) - return prompt + other - - -class MessagesPlaceholder(BaseMessagePromptTemplate): - """Prompt template that assumes variable is already list of messages.""" - - variable_name: str - """Name of variable to use as messages.""" - - def __init__(self, variable_name: str, **kwargs: Any): - return super().__init__(variable_name=variable_name, **kwargs) - - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - """Format messages from kwargs. - - Args: - **kwargs: Keyword arguments to use for formatting. - - Returns: - List of BaseMessage. - """ - value = kwargs[self.variable_name] - if not isinstance(value, list): - raise ValueError( - f"variable {self.variable_name} should be a list of base messages, " - f"got {value}" - ) - for v in value: - if not isinstance(v, BaseMessage): - raise ValueError( - f"variable {self.variable_name} should be a list of base messages," - f" got {value}" - ) - return value - - @property - def input_variables(self) -> List[str]: - """Input variables for this prompt template. - - Returns: - List of input variable names. - """ - return [self.variable_name] - - -MessagePromptTemplateT = TypeVar( - "MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate" -) -"""Type variable for message prompt templates.""" - - -class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): - """Base class for message prompt templates that use a string prompt template.""" - - prompt: StringPromptTemplate - """String prompt template.""" - additional_kwargs: dict = Field(default_factory=dict) - """Additional keyword arguments to pass to the prompt template.""" - - @classmethod - def from_template( - cls: Type[MessagePromptTemplateT], - template: str, - template_format: str = "f-string", - **kwargs: Any, - ) -> MessagePromptTemplateT: - """Create a class from a string template. - - Args: - template: a template. - template_format: format of the template. - **kwargs: keyword arguments to pass to the constructor. - - Returns: - A new instance of this class. - """ - prompt = PromptTemplate.from_template(template, template_format=template_format) - return cls(prompt=prompt, **kwargs) - - @classmethod - def from_template_file( - cls: Type[MessagePromptTemplateT], - template_file: Union[str, Path], - input_variables: List[str], - **kwargs: Any, - ) -> MessagePromptTemplateT: - """Create a class from a template file. - - Args: - template_file: path to a template file. String or Path. - input_variables: list of input variables. - **kwargs: keyword arguments to pass to the constructor. - - Returns: - A new instance of this class. - """ - prompt = PromptTemplate.from_file(template_file, input_variables) - return cls(prompt=prompt, **kwargs) - - @abstractmethod - def format(self, **kwargs: Any) -> BaseMessage: - """Format the prompt template. - - Args: - **kwargs: Keyword arguments to use for formatting. - - Returns: - Formatted message. - """ - - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - """Format messages from kwargs. - - Args: - **kwargs: Keyword arguments to use for formatting. - - Returns: - List of BaseMessages. 
- """ - return [self.format(**kwargs)] - - @property - def input_variables(self) -> List[str]: - """ - Input variables for this prompt template. - - Returns: - List of input variable names. - """ - return self.prompt.input_variables - - -class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): - """Chat message prompt template.""" - - role: str - """Role of the message.""" - - def format(self, **kwargs: Any) -> BaseMessage: - """Format the prompt template. - - Args: - **kwargs: Keyword arguments to use for formatting. - - Returns: - Formatted message. - """ - text = self.prompt.format(**kwargs) - return ChatMessage( - content=text, role=self.role, additional_kwargs=self.additional_kwargs - ) - - -class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate): - """Human message prompt template. This is a message sent from the user.""" - - def format(self, **kwargs: Any) -> BaseMessage: - """Format the prompt template. - - Args: - **kwargs: Keyword arguments to use for formatting. - - Returns: - Formatted message. - """ - text = self.prompt.format(**kwargs) - return HumanMessage(content=text, additional_kwargs=self.additional_kwargs) - - -class AIMessagePromptTemplate(BaseStringMessagePromptTemplate): - """AI message prompt template. This is a message sent from the AI.""" - - def format(self, **kwargs: Any) -> BaseMessage: - """Format the prompt template. - - Args: - **kwargs: Keyword arguments to use for formatting. - - Returns: - Formatted message. - """ - text = self.prompt.format(**kwargs) - return AIMessage(content=text, additional_kwargs=self.additional_kwargs) - - -class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate): - """System message prompt template. - This is a message that is not sent to the user. - """ - - def format(self, **kwargs: Any) -> BaseMessage: - """Format the prompt template. - - Args: - **kwargs: Keyword arguments to use for formatting. - - Returns: - Formatted message. - """ - text = self.prompt.format(**kwargs) - return SystemMessage(content=text, additional_kwargs=self.additional_kwargs) - - -class ChatPromptValue(PromptValue): - """Chat prompt value. - - A type of a prompt value that is built from messages. - """ - - messages: Sequence[BaseMessage] - """List of messages.""" - - def to_string(self) -> str: - """Return prompt as string.""" - return get_buffer_string(self.messages) - - def to_messages(self) -> List[BaseMessage]: - """Return prompt as a list of messages.""" - return list(self.messages) - - -class ChatPromptValueConcrete(ChatPromptValue): - """Chat prompt value which explicitly lists out the message types it accepts. - For use in external schemas.""" - - messages: Sequence[AnyMessage] - - type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete" - - -class BaseChatPromptTemplate(BasePromptTemplate, ABC): - """Base class for chat prompt templates.""" - - @property - def lc_attributes(self) -> Dict: - """ - Return a list of attribute names that should be included in the - serialized kwargs. These attributes must be accepted by the - constructor. - """ - return {"input_variables": self.input_variables} - - def format(self, **kwargs: Any) -> str: - """Format the chat template into a string. - - Args: - **kwargs: keyword arguments to use for filling in template variables - in all the template messages in this chat template. - - Returns: - formatted string - """ - return self.format_prompt(**kwargs).to_string() - - def format_prompt(self, **kwargs: Any) -> PromptValue: - """ - Format prompt. Should return a PromptValue. 
- Args: - **kwargs: Keyword arguments to use for formatting. - - Returns: - PromptValue. - """ - messages = self.format_messages(**kwargs) - return ChatPromptValue(messages=messages) - - @abstractmethod - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - """Format kwargs into a list of messages.""" - - -MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate] - -MessageLikeRepresentation = Union[ - MessageLike, - Tuple[str, str], - Tuple[Type, str], - str, +__all__ = [ + "BaseMessagePromptTemplate", + "MessagesPlaceholder", + "BaseStringMessagePromptTemplate", + "ChatMessagePromptTemplate", + "HumanMessagePromptTemplate", + "AIMessagePromptTemplate", + "SystemMessagePromptTemplate", + "ChatPromptValue", + "ChatPromptValueConcrete", + "BaseChatPromptTemplate", + "ChatPromptTemplate", ] - - -class ChatPromptTemplate(BaseChatPromptTemplate): - """A prompt template for chat models. - - Use to create flexible templated prompts for chat models. - - Examples: - - .. code-block:: python - - from langchain.prompts import ChatPromptTemplate - - template = ChatPromptTemplate.from_messages([ - ("system", "You are a helpful AI bot. Your name is {name}."), - ("human", "Hello, how are you doing?"), - ("ai", "I'm doing well, thanks!"), - ("human", "{user_input}"), - ]) - - messages = template.format_messages( - name="Bob", - user_input="What is your name?" - ) - """ - - input_variables: List[str] - """List of input variables in template messages. Used for validation.""" - messages: List[MessageLike] - """List of messages consisting of either message prompt templates or messages.""" - validate_template: bool = False - """Whether or not to try validating the template.""" - - def __add__(self, other: Any) -> ChatPromptTemplate: - """Combine two prompt templates. - - Args: - other: Another prompt template. - - Returns: - Combined prompt template. - """ - # Allow for easy combining - if isinstance(other, ChatPromptTemplate): - return ChatPromptTemplate(messages=self.messages + other.messages) - elif isinstance( - other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate) - ): - return ChatPromptTemplate(messages=self.messages + [other]) - elif isinstance(other, (list, tuple)): - _other = ChatPromptTemplate.from_messages(other) - return ChatPromptTemplate(messages=self.messages + _other.messages) - elif isinstance(other, str): - prompt = HumanMessagePromptTemplate.from_template(other) - return ChatPromptTemplate(messages=self.messages + [prompt]) - else: - raise NotImplementedError(f"Unsupported operand type for +: {type(other)}") - - @root_validator(pre=True) - def validate_input_variables(cls, values: dict) -> dict: - """Validate input variables. - - If input_variables is not set, it will be set to the union of - all input variables in the messages. - - Args: - values: values to validate. - - Returns: - Validated values. 
- """ - messages = values["messages"] - input_vars = set() - input_types: Dict[str, Any] = values.get("input_types", {}) - for message in messages: - if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)): - input_vars.update(message.input_variables) - if isinstance(message, MessagesPlaceholder): - if message.variable_name not in input_types: - input_types[message.variable_name] = List[AnyMessage] - if "partial_variables" in values: - input_vars = input_vars - set(values["partial_variables"]) - if "input_variables" in values and values.get("validate_template"): - if input_vars != set(values["input_variables"]): - raise ValueError( - "Got mismatched input_variables. " - f"Expected: {input_vars}. " - f"Got: {values['input_variables']}" - ) - else: - values["input_variables"] = sorted(input_vars) - values["input_types"] = input_types - return values - - @classmethod - def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate: - """Create a chat prompt template from a template string. - - Creates a chat template consisting of a single message assumed to be from - the human. - - Args: - template: template string - **kwargs: keyword arguments to pass to the constructor. - - Returns: - A new instance of this class. - """ - prompt_template = PromptTemplate.from_template(template, **kwargs) - message = HumanMessagePromptTemplate(prompt=prompt_template) - return cls.from_messages([message]) - - @classmethod - @deprecated("0.0.260", alternative="from_messages classmethod", pending=True) - def from_role_strings( - cls, string_messages: List[Tuple[str, str]] - ) -> ChatPromptTemplate: - """Create a chat prompt template from a list of (role, template) tuples. - - Args: - string_messages: list of (role, template) tuples. - - Returns: - a chat prompt template - """ - return cls( - messages=[ - ChatMessagePromptTemplate.from_template(template, role=role) - for role, template in string_messages - ] - ) - - @classmethod - @deprecated("0.0.260", alternative="from_messages classmethod", pending=True) - def from_strings( - cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]] - ) -> ChatPromptTemplate: - """Create a chat prompt template from a list of (role class, template) tuples. - - Args: - string_messages: list of (role class, template) tuples. - - Returns: - a chat prompt template - """ - return cls.from_messages(string_messages) - - @classmethod - def from_messages( - cls, - messages: Sequence[MessageLikeRepresentation], - ) -> ChatPromptTemplate: - """Create a chat prompt template from a variety of message formats. - - Examples: - - Instantiation from a list of message templates: - - .. code-block:: python - - template = ChatPromptTemplate.from_messages([ - ("human", "Hello, how are you?"), - ("ai", "I'm doing well, thanks!"), - ("human", "That's good to hear."), - ]) - - Instantiation from mixed message formats: - - .. code-block:: python - - template = ChatPromptTemplate.from_messages([ - SystemMessage(content="hello"), - ("human", "Hello, how are you?"), - ]) - - Args: - messages: sequence of message representations. 
- A message can be represented using the following formats: - (1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of - (message type, template); e.g., ("human", "{user_input}"), - (4) 2-tuple of (message class, template), (4) a string which is - shorthand for ("human", template); e.g., "{user_input}" - - Returns: - a chat prompt template - """ - _messages = [_convert_to_message(message) for message in messages] - - # Automatically infer input variables from messages - input_vars: Set[str] = set() - for _message in _messages: - if isinstance( - _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate) - ): - input_vars.update(_message.input_variables) - - return cls(input_variables=sorted(input_vars), messages=_messages) - - def format(self, **kwargs: Any) -> str: - """Format the chat template into a string. - - Args: - **kwargs: keyword arguments to use for filling in template variables - in all the template messages in this chat template. - - Returns: - formatted string - """ - return self.format_prompt(**kwargs).to_string() - - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - """Format the chat template into a list of finalized messages. - - Args: - **kwargs: keyword arguments to use for filling in template variables - in all the template messages in this chat template. - - Returns: - list of formatted messages - """ - kwargs = self._merge_partial_and_user_variables(**kwargs) - result = [] - for message_template in self.messages: - if isinstance(message_template, BaseMessage): - result.extend([message_template]) - elif isinstance( - message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate) - ): - rel_params = { - k: v - for k, v in kwargs.items() - if k in message_template.input_variables - } - message = message_template.format_messages(**rel_params) - result.extend(message) - else: - raise ValueError(f"Unexpected input: {message_template}") - return result - - def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate: - """Get a new ChatPromptTemplate with some input variables already filled in. - - Args: - **kwargs: keyword arguments to use for filling in template variables. Ought - to be a subset of the input variables. - - Returns: - A new ChatPromptTemplate. - - - Example: - - .. code-block:: python - - from langchain.prompts import ChatPromptTemplate - - template = ChatPromptTemplate.from_messages( - [ - ("system", "You are an AI assistant named {name}."), - ("human", "Hi I'm {user}"), - ("ai", "Hi there, {user}, I'm {name}."), - ("human", "{input}"), - ] - ) - template2 = template.partial(user="Lucy", name="R2D2") - - template2.format_messages(input="hello") - """ - prompt_dict = self.__dict__.copy() - prompt_dict["input_variables"] = list( - set(self.input_variables).difference(kwargs) - ) - prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} - return type(self)(**prompt_dict) - - def append(self, message: MessageLikeRepresentation) -> None: - """Append message to the end of the chat template. - - Args: - message: representation of a message to append. - """ - self.messages.append(_convert_to_message(message)) - - def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None: - """Extend the chat template with a sequence of messages.""" - self.messages.extend([_convert_to_message(message) for message in messages]) - - @overload - def __getitem__(self, index: int) -> MessageLike: - ... - - @overload - def __getitem__(self, index: slice) -> ChatPromptTemplate: - ... 
- - def __getitem__( - self, index: Union[int, slice] - ) -> Union[MessageLike, ChatPromptTemplate]: - """Use to index into the chat template.""" - if isinstance(index, slice): - start, stop, step = index.indices(len(self.messages)) - messages = self.messages[start:stop:step] - return ChatPromptTemplate.from_messages(messages) - else: - return self.messages[index] - - def __len__(self) -> int: - """Get the length of the chat template.""" - return len(self.messages) - - @property - def _prompt_type(self) -> str: - """Name of prompt type.""" - return "chat" - - def save(self, file_path: Union[Path, str]) -> None: - """Save prompt to file. - - Args: - file_path: path to file. - """ - raise NotImplementedError() - - -def _create_template_from_message_type( - message_type: str, template: str -) -> BaseMessagePromptTemplate: - """Create a message prompt template from a message type and template string. - - Args: - message_type: str the type of the message template (e.g., "human", "ai", etc.) - template: str the template string. - - Returns: - a message prompt template of the appropriate type. - """ - if message_type in ("human", "user"): - message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template( - template - ) - elif message_type in ("ai", "assistant"): - message = AIMessagePromptTemplate.from_template(template) - elif message_type == "system": - message = SystemMessagePromptTemplate.from_template(template) - else: - raise ValueError( - f"Unexpected message type: {message_type}. Use one of 'human'," - f" 'user', 'ai', 'assistant', or 'system'." - ) - return message - - -def _convert_to_message( - message: MessageLikeRepresentation, -) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]: - """Instantiate a message from a variety of message formats. 
- - The message format can be one of the following: - - - BaseMessagePromptTemplate - - BaseMessage - - 2-tuple of (role string, template); e.g., ("human", "{user_input}") - - 2-tuple of (message class, template) - - string: shorthand for ("human", template); e.g., "{user_input}" - - Args: - message: a representation of a message in one of the supported formats - - Returns: - an instance of a message or a message template - """ - if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)): - _message: Union[ - BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate - ] = message - elif isinstance(message, BaseMessage): - _message = message - elif isinstance(message, str): - _message = _create_template_from_message_type("human", message) - elif isinstance(message, tuple): - if len(message) != 2: - raise ValueError(f"Expected 2-tuple of (role, template), got {message}") - message_type_str, template = message - if isinstance(message_type_str, str): - _message = _create_template_from_message_type(message_type_str, template) - else: - _message = message_type_str(prompt=PromptTemplate.from_template(template)) - else: - raise NotImplementedError(f"Unsupported message type: {type(message)}") - - return _message diff --git a/libs/langchain/langchain/prompts/example_selector/__init__.py b/libs/langchain/langchain/prompts/example_selector/__init__.py index fedc5db79f5..7cb71659c62 100644 --- a/libs/langchain/langchain/prompts/example_selector/__init__.py +++ b/libs/langchain/langchain/prompts/example_selector/__init__.py @@ -1,11 +1,16 @@ """Logic for selecting examples to include in prompts.""" -from langchain.prompts.example_selector.length_based import LengthBasedExampleSelector -from langchain.prompts.example_selector.ngram_overlap import NGramOverlapExampleSelector -from langchain.prompts.example_selector.semantic_similarity import ( +from langchain_core.prompts.example_selector.length_based import ( + LengthBasedExampleSelector, +) +from langchain_core.prompts.example_selector.semantic_similarity import ( MaxMarginalRelevanceExampleSelector, SemanticSimilarityExampleSelector, ) +from langchain.prompts.example_selector.ngram_overlap import ( + NGramOverlapExampleSelector, +) + __all__ = [ "LengthBasedExampleSelector", "MaxMarginalRelevanceExampleSelector", diff --git a/libs/langchain/langchain/prompts/example_selector/base.py b/libs/langchain/langchain/prompts/example_selector/base.py index ff2e099c810..3649ca63e67 100644 --- a/libs/langchain/langchain/prompts/example_selector/base.py +++ b/libs/langchain/langchain/prompts/example_selector/base.py @@ -1,15 +1,3 @@ -"""Interface for selecting examples to include in prompts.""" -from abc import ABC, abstractmethod -from typing import Any, Dict, List +from langchain_core.prompts.example_selector.base import BaseExampleSelector - -class BaseExampleSelector(ABC): - """Interface for selecting examples to include in prompts.""" - - @abstractmethod - def add_example(self, example: Dict[str, str]) -> Any: - """Add new example to store for a key.""" - - @abstractmethod - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: - """Select which examples to use based on the inputs.""" +__all__ = ["BaseExampleSelector"] diff --git a/libs/langchain/langchain/prompts/example_selector/length_based.py b/libs/langchain/langchain/prompts/example_selector/length_based.py index 19f4a0c5c8e..e9edb8fcff4 100644 --- a/libs/langchain/langchain/prompts/example_selector/length_based.py +++ 
b/libs/langchain/langchain/prompts/example_selector/length_based.py @@ -1,63 +1,5 @@ -"""Select examples based on length.""" -import re -from typing import Callable, Dict, List +from langchain_core.prompts.example_selector.length_based import ( + LengthBasedExampleSelector, +) -from langchain.prompts.example_selector.base import BaseExampleSelector -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import BaseModel, validator - - -def _get_length_based(text: str) -> int: - return len(re.split("\n| ", text)) - - -class LengthBasedExampleSelector(BaseExampleSelector, BaseModel): - """Select examples based on length.""" - - examples: List[dict] - """A list of the examples that the prompt template expects.""" - - example_prompt: PromptTemplate - """Prompt template used to format the examples.""" - - get_text_length: Callable[[str], int] = _get_length_based - """Function to measure prompt length. Defaults to word count.""" - - max_length: int = 2048 - """Max length for the prompt, beyond which examples are cut.""" - - example_text_lengths: List[int] = [] #: :meta private: - - def add_example(self, example: Dict[str, str]) -> None: - """Add new example to list.""" - self.examples.append(example) - string_example = self.example_prompt.format(**example) - self.example_text_lengths.append(self.get_text_length(string_example)) - - @validator("example_text_lengths", always=True) - def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]: - """Calculate text lengths if they don't exist.""" - # Check if text lengths were passed in - if v: - return v - # If they were not, calculate them - example_prompt = values["example_prompt"] - get_text_length = values["get_text_length"] - string_examples = [example_prompt.format(**eg) for eg in values["examples"]] - return [get_text_length(eg) for eg in string_examples] - - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: - """Select which examples to use based on the input lengths.""" - inputs = " ".join(input_variables.values()) - remaining_length = self.max_length - self.get_text_length(inputs) - i = 0 - examples = [] - while remaining_length > 0 and i < len(self.examples): - new_length = remaining_length - self.example_text_lengths[i] - if new_length < 0: - break - else: - examples.append(self.examples[i]) - remaining_length = new_length - i += 1 - return examples +__all__ = ["LengthBasedExampleSelector"] diff --git a/libs/langchain/langchain/prompts/example_selector/ngram_overlap.py b/libs/langchain/langchain/prompts/example_selector/ngram_overlap.py index 39147e1ff6c..21bfa4c411d 100644 --- a/libs/langchain/langchain/prompts/example_selector/ngram_overlap.py +++ b/libs/langchain/langchain/prompts/example_selector/ngram_overlap.py @@ -6,10 +6,9 @@ https://aclanthology.org/P02-1040.pdf from typing import Dict, List import numpy as np - -from langchain.prompts.example_selector.base import BaseExampleSelector -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import BaseModel, root_validator +from langchain_core.prompts.example_selector.base import BaseExampleSelector +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import BaseModel, root_validator def ngram_overlap_score(source: List[str], example: List[str]) -> float: diff --git a/libs/langchain/langchain/prompts/example_selector/semantic_similarity.py b/libs/langchain/langchain/prompts/example_selector/semantic_similarity.py index 4548b3b2878..2f730f1895f 100644 
--- a/libs/langchain/langchain/prompts/example_selector/semantic_similarity.py +++ b/libs/langchain/langchain/prompts/example_selector/semantic_similarity.py @@ -1,165 +1,11 @@ -"""Example selector that selects examples based on SemanticSimilarity.""" -from __future__ import annotations +from langchain_core.prompts.example_selector.semantic_similarity import ( + MaxMarginalRelevanceExampleSelector, + SemanticSimilarityExampleSelector, + sorted_values, +) -from typing import Any, Dict, List, Optional, Type - -from langchain.prompts.example_selector.base import BaseExampleSelector -from langchain.pydantic_v1 import BaseModel, Extra -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore - - -def sorted_values(values: Dict[str, str]) -> List[Any]: - """Return a list of values in dict sorted by key.""" - return [values[val] for val in sorted(values)] - - -class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel): - """Example selector that selects examples based on SemanticSimilarity.""" - - vectorstore: VectorStore - """VectorStore than contains information about examples.""" - k: int = 4 - """Number of examples to select.""" - example_keys: Optional[List[str]] = None - """Optional keys to filter examples to.""" - input_keys: Optional[List[str]] = None - """Optional keys to filter input to. If provided, the search is based on - the input variables instead of all variables.""" - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - def add_example(self, example: Dict[str, str]) -> str: - """Add new example to vectorstore.""" - if self.input_keys: - string_example = " ".join( - sorted_values({key: example[key] for key in self.input_keys}) - ) - else: - string_example = " ".join(sorted_values(example)) - ids = self.vectorstore.add_texts([string_example], metadatas=[example]) - return ids[0] - - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: - """Select which examples to use based on semantic similarity.""" - # Get the docs with the highest similarity. - if self.input_keys: - input_variables = {key: input_variables[key] for key in self.input_keys} - query = " ".join(sorted_values(input_variables)) - example_docs = self.vectorstore.similarity_search(query, k=self.k) - # Get the examples from the metadata. - # This assumes that examples are stored in metadata. - examples = [dict(e.metadata) for e in example_docs] - # If example keys are provided, filter examples to those keys. - if self.example_keys: - examples = [{k: eg[k] for k in self.example_keys} for eg in examples] - return examples - - @classmethod - def from_examples( - cls, - examples: List[dict], - embeddings: Embeddings, - vectorstore_cls: Type[VectorStore], - k: int = 4, - input_keys: Optional[List[str]] = None, - **vectorstore_cls_kwargs: Any, - ) -> SemanticSimilarityExampleSelector: - """Create k-shot example selector using example list and embeddings. - - Reshuffles examples dynamically based on query similarity. - - Args: - examples: List of examples to use in the prompt. - embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings(). - vectorstore_cls: A vector store DB interface class, e.g. FAISS. - k: Number of examples to select - input_keys: If provided, the search is based on the input variables - instead of all variables. 
- vectorstore_cls_kwargs: optional kwargs containing url for vector store - - Returns: - The ExampleSelector instantiated, backed by a vector store. - """ - if input_keys: - string_examples = [ - " ".join(sorted_values({k: eg[k] for k in input_keys})) - for eg in examples - ] - else: - string_examples = [" ".join(sorted_values(eg)) for eg in examples] - vectorstore = vectorstore_cls.from_texts( - string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs - ) - return cls(vectorstore=vectorstore, k=k, input_keys=input_keys) - - -class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector): - """ExampleSelector that selects examples based on Max Marginal Relevance. - - This was shown to improve performance in this paper: - https://arxiv.org/pdf/2211.13892.pdf - """ - - fetch_k: int = 20 - """Number of examples to fetch to rerank.""" - - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: - """Select which examples to use based on semantic similarity.""" - # Get the docs with the highest similarity. - if self.input_keys: - input_variables = {key: input_variables[key] for key in self.input_keys} - query = " ".join(sorted_values(input_variables)) - example_docs = self.vectorstore.max_marginal_relevance_search( - query, k=self.k, fetch_k=self.fetch_k - ) - # Get the examples from the metadata. - # This assumes that examples are stored in metadata. - examples = [dict(e.metadata) for e in example_docs] - # If example keys are provided, filter examples to those keys. - if self.example_keys: - examples = [{k: eg[k] for k in self.example_keys} for eg in examples] - return examples - - @classmethod - def from_examples( - cls, - examples: List[dict], - embeddings: Embeddings, - vectorstore_cls: Type[VectorStore], - k: int = 4, - input_keys: Optional[List[str]] = None, - fetch_k: int = 20, - **vectorstore_cls_kwargs: Any, - ) -> MaxMarginalRelevanceExampleSelector: - """Create k-shot example selector using example list and embeddings. - - Reshuffles examples dynamically based on query similarity. - - Args: - examples: List of examples to use in the prompt. - embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings(). - vectorstore_cls: A vector store DB interface class, e.g. FAISS. - k: Number of examples to select - input_keys: If provided, the search is based on the input variables - instead of all variables. - vectorstore_cls_kwargs: optional kwargs containing url for vector store - - Returns: - The ExampleSelector instantiated, backed by a vector store. 
- """ - if input_keys: - string_examples = [ - " ".join(sorted_values({k: eg[k] for k in input_keys})) - for eg in examples - ] - else: - string_examples = [" ".join(sorted_values(eg)) for eg in examples] - vectorstore = vectorstore_cls.from_texts( - string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs - ) - return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys) +__all__ = [ + "sorted_values", + "SemanticSimilarityExampleSelector", + "MaxMarginalRelevanceExampleSelector", +] diff --git a/libs/langchain/langchain/prompts/few_shot.py b/libs/langchain/langchain/prompts/few_shot.py index 4016336d70f..ab8e24098ed 100644 --- a/libs/langchain/langchain/prompts/few_shot.py +++ b/libs/langchain/langchain/prompts/few_shot.py @@ -1,340 +1,6 @@ -"""Prompt template that contains few shot examples.""" -from __future__ import annotations - -from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Union - -from langchain.prompts.base import ( - DEFAULT_FORMATTER_MAPPING, - StringPromptTemplate, - check_valid_template, - get_template_variables, +from langchain_core.prompts.few_shot import ( + FewShotChatMessagePromptTemplate, + FewShotPromptTemplate, ) -from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate -from langchain.prompts.example_selector.base import BaseExampleSelector -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.schema.messages import BaseMessage, get_buffer_string - -class _FewShotPromptTemplateMixin(BaseModel): - """Prompt template that contains few shot examples.""" - - examples: Optional[List[dict]] = None - """Examples to format into the prompt. - Either this or example_selector should be provided.""" - - example_selector: Optional[BaseExampleSelector] = None - """ExampleSelector to choose the examples to format into the prompt. - Either this or examples should be provided.""" - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @root_validator(pre=True) - def check_examples_and_selector(cls, values: Dict) -> Dict: - """Check that one and only one of examples/example_selector are provided.""" - examples = values.get("examples", None) - example_selector = values.get("example_selector", None) - if examples and example_selector: - raise ValueError( - "Only one of 'examples' and 'example_selector' should be provided" - ) - - if examples is None and example_selector is None: - raise ValueError( - "One of 'examples' and 'example_selector' should be provided" - ) - - return values - - def _get_examples(self, **kwargs: Any) -> List[dict]: - """Get the examples to use for formatting the prompt. - - Args: - **kwargs: Keyword arguments to be passed to the example selector. - - Returns: - List of examples. 
- """ - if self.examples is not None: - return self.examples - elif self.example_selector is not None: - return self.example_selector.select_examples(kwargs) - else: - raise ValueError( - "One of 'examples' and 'example_selector' should be provided" - ) - - -class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): - """Prompt template that contains few shot examples.""" - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether or not the class is serializable.""" - return False - - validate_template: bool = False - """Whether or not to try validating the template.""" - - input_variables: List[str] - """A list of the names of the variables the prompt template expects.""" - - example_prompt: PromptTemplate - """PromptTemplate used to format an individual example.""" - - suffix: str - """A prompt template string to put after the examples.""" - - example_separator: str = "\n\n" - """String separator used to join the prefix, the examples, and suffix.""" - - prefix: str = "" - """A prompt template string to put before the examples.""" - - template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string" - """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" - - @root_validator() - def template_is_valid(cls, values: Dict) -> Dict: - """Check that prefix, suffix, and input variables are consistent.""" - if values["validate_template"]: - check_valid_template( - values["prefix"] + values["suffix"], - values["template_format"], - values["input_variables"] + list(values["partial_variables"]), - ) - elif values.get("template_format"): - values["input_variables"] = [ - var - for var in get_template_variables( - values["prefix"] + values["suffix"], values["template_format"] - ) - if var not in values["partial_variables"] - ] - return values - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - def format(self, **kwargs: Any) -> str: - """Format the prompt with the inputs. - - Args: - **kwargs: Any arguments to be passed to the prompt template. - - Returns: - A formatted string. - - Example: - - .. code-block:: python - - prompt.format(variable1="foo") - """ - kwargs = self._merge_partial_and_user_variables(**kwargs) - # Get the examples to use. - examples = self._get_examples(**kwargs) - examples = [ - {k: e[k] for k in self.example_prompt.input_variables} for e in examples - ] - # Format the examples. - example_strings = [ - self.example_prompt.format(**example) for example in examples - ] - # Create the overall template. - pieces = [self.prefix, *example_strings, self.suffix] - template = self.example_separator.join([piece for piece in pieces if piece]) - - # Format the template with the input variables. - return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs) - - @property - def _prompt_type(self) -> str: - """Return the prompt type key.""" - return "few_shot" - - def save(self, file_path: Union[Path, str]) -> None: - if self.example_selector: - raise ValueError("Saving an example selector is not currently supported") - return super().save(file_path) - - -class FewShotChatMessagePromptTemplate( - BaseChatPromptTemplate, _FewShotPromptTemplateMixin -): - """Chat prompt template that supports few-shot examples. - - The high level structure of produced by this prompt template is a list of messages - consisting of prefix message(s), example message(s), and suffix message(s). 
- - This structure enables creating a conversation with intermediate examples like: - - System: You are a helpful AI Assistant - Human: What is 2+2? - AI: 4 - Human: What is 2+3? - AI: 5 - Human: What is 4+4? - - This prompt template can be used to generate a fixed list of examples or else - to dynamically select examples based on the input. - - Examples: - - Prompt template with a fixed list of examples (matching the sample - conversation above): - - .. code-block:: python - - from langchain.prompts import ( - FewShotChatMessagePromptTemplate, - ChatPromptTemplate - ) - - examples = [ - {"input": "2+2", "output": "4"}, - {"input": "2+3", "output": "5"}, - ] - - example_prompt = ChatPromptTemplate.from_messages( - [('human', '{input}'), ('ai', '{output}')] - ) - - few_shot_prompt = FewShotChatMessagePromptTemplate( - examples=examples, - # This is a prompt template used to format each individual example. - example_prompt=example_prompt, - ) - - final_prompt = ChatPromptTemplate.from_messages( - [ - ('system', 'You are a helpful AI Assistant'), - few_shot_prompt, - ('human', '{input}'), - ] - ) - final_prompt.format(input="What is 4+4?") - - Prompt template with dynamically selected examples: - - .. code-block:: python - - from langchain.prompts import SemanticSimilarityExampleSelector - from langchain.embeddings import OpenAIEmbeddings - from langchain.vectorstores import Chroma - - examples = [ - {"input": "2+2", "output": "4"}, - {"input": "2+3", "output": "5"}, - {"input": "2+4", "output": "6"}, - # ... - ] - - to_vectorize = [ - " ".join(example.values()) - for example in examples - ] - embeddings = OpenAIEmbeddings() - vectorstore = Chroma.from_texts( - to_vectorize, embeddings, metadatas=examples - ) - example_selector = SemanticSimilarityExampleSelector( - vectorstore=vectorstore - ) - - from langchain.schema import SystemMessage - from langchain.prompts import HumanMessagePromptTemplate - from langchain.prompts.few_shot import FewShotChatMessagePromptTemplate - - few_shot_prompt = FewShotChatMessagePromptTemplate( - # Which variable(s) will be passed to the example selector. - input_variables=["input"], - example_selector=example_selector, - # Define how each example will be formatted. - # In this case, each example will become 2 messages: - # 1 human, and 1 AI - example_prompt=( - HumanMessagePromptTemplate.from_template("{input}") - + AIMessagePromptTemplate.from_template("{output}") - ), - ) - # Define the overall prompt. 
- final_prompt = ( - SystemMessagePromptTemplate.from_template( - "You are a helpful AI Assistant" - ) - + few_shot_prompt - + HumanMessagePromptTemplate.from_template("{input}") - ) - # Show the prompt - print(final_prompt.format_messages(input="What's 3+3?")) - - # Use within an LLM - from langchain.chat_models import ChatAnthropic - chain = final_prompt | ChatAnthropic() - chain.invoke({"input": "What's 3+3?"}) - """ - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether or not the class is serializable.""" - return False - - input_variables: List[str] = Field(default_factory=list) - """A list of the names of the variables the prompt template will use - to pass to the example_selector, if provided.""" - example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate] - """The class to format each example.""" - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - """Format kwargs into a list of messages. - - Args: - **kwargs: keyword arguments to use for filling in templates in messages. - - Returns: - A list of formatted messages with all template variables filled in. - """ - # Get the examples to use. - examples = self._get_examples(**kwargs) - examples = [ - {k: e[k] for k in self.example_prompt.input_variables} for e in examples - ] - # Format the examples. - messages = [ - message - for example in examples - for message in self.example_prompt.format_messages(**example) - ] - return messages - - def format(self, **kwargs: Any) -> str: - """Format the prompt with inputs generating a string. - - Use this method to generate a string representation of a prompt consisting - of chat messages. - - Useful for feeding into a string based completion language model or debugging. - - Args: - **kwargs: keyword arguments to use for formatting. - - Returns: - A string representation of the prompt - """ - messages = self.format_messages(**kwargs) - return get_buffer_string(messages) +__all__ = ["FewShotPromptTemplate", "FewShotChatMessagePromptTemplate"] diff --git a/libs/langchain/langchain/prompts/few_shot_with_templates.py b/libs/langchain/langchain/prompts/few_shot_with_templates.py index bee5fe71e03..7e530dbe9cb 100644 --- a/libs/langchain/langchain/prompts/few_shot_with_templates.py +++ b/libs/langchain/langchain/prompts/few_shot_with_templates.py @@ -1,153 +1,3 @@ -"""Prompt template that contains few shot examples.""" -from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemplates -from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate -from langchain.prompts.example_selector.base import BaseExampleSelector -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import Extra, root_validator - - -class FewShotPromptWithTemplates(StringPromptTemplate): - """Prompt template that contains few shot examples.""" - - examples: Optional[List[dict]] = None - """Examples to format into the prompt. - Either this or example_selector should be provided.""" - - example_selector: Optional[BaseExampleSelector] = None - """ExampleSelector to choose the examples to format into the prompt. 
- Either this or examples should be provided.""" - - example_prompt: PromptTemplate - """PromptTemplate used to format an individual example.""" - - suffix: StringPromptTemplate - """A PromptTemplate to put after the examples.""" - - input_variables: List[str] - """A list of the names of the variables the prompt template expects.""" - - example_separator: str = "\n\n" - """String separator used to join the prefix, the examples, and suffix.""" - - prefix: Optional[StringPromptTemplate] = None - """A PromptTemplate to put before the examples.""" - - template_format: str = "f-string" - """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" - - validate_template: bool = False - """Whether or not to try validating the template.""" - - @root_validator(pre=True) - def check_examples_and_selector(cls, values: Dict) -> Dict: - """Check that one and only one of examples/example_selector are provided.""" - examples = values.get("examples", None) - example_selector = values.get("example_selector", None) - if examples and example_selector: - raise ValueError( - "Only one of 'examples' and 'example_selector' should be provided" - ) - - if examples is None and example_selector is None: - raise ValueError( - "One of 'examples' and 'example_selector' should be provided" - ) - - return values - - @root_validator() - def template_is_valid(cls, values: Dict) -> Dict: - """Check that prefix, suffix, and input variables are consistent.""" - if values["validate_template"]: - input_variables = values["input_variables"] - expected_input_variables = set(values["suffix"].input_variables) - expected_input_variables |= set(values["partial_variables"]) - if values["prefix"] is not None: - expected_input_variables |= set(values["prefix"].input_variables) - missing_vars = expected_input_variables.difference(input_variables) - if missing_vars: - raise ValueError( - f"Got input_variables={input_variables}, but based on " - f"prefix/suffix expected {expected_input_variables}" - ) - else: - values["input_variables"] = sorted( - set(values["suffix"].input_variables) - | set(values["prefix"].input_variables if values["prefix"] else []) - - set(values["partial_variables"]) - ) - return values - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - def _get_examples(self, **kwargs: Any) -> List[dict]: - if self.examples is not None: - return self.examples - elif self.example_selector is not None: - return self.example_selector.select_examples(kwargs) - else: - raise ValueError - - def format(self, **kwargs: Any) -> str: - """Format the prompt with the inputs. - - Args: - kwargs: Any arguments to be passed to the prompt template. - - Returns: - A formatted string. - - Example: - - .. code-block:: python - - prompt.format(variable1="foo") - """ - kwargs = self._merge_partial_and_user_variables(**kwargs) - # Get the examples to use. - examples = self._get_examples(**kwargs) - # Format the examples. - example_strings = [ - self.example_prompt.format(**example) for example in examples - ] - # Create the overall prefix. 
- if self.prefix is None: - prefix = "" - else: - prefix_kwargs = { - k: v for k, v in kwargs.items() if k in self.prefix.input_variables - } - for k in prefix_kwargs.keys(): - kwargs.pop(k) - prefix = self.prefix.format(**prefix_kwargs) - - # Create the overall suffix - suffix_kwargs = { - k: v for k, v in kwargs.items() if k in self.suffix.input_variables - } - for k in suffix_kwargs.keys(): - kwargs.pop(k) - suffix = self.suffix.format( - **suffix_kwargs, - ) - - pieces = [prefix, *example_strings, suffix] - template = self.example_separator.join([piece for piece in pieces if piece]) - # Format the template with the input variables. - return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs) - - @property - def _prompt_type(self) -> str: - """Return the prompt type key.""" - return "few_shot_with_templates" - - def save(self, file_path: Union[Path, str]) -> None: - if self.example_selector: - raise ValueError("Saving an example selector is not currently supported") - return super().save(file_path) +__all__ = ["FewShotPromptWithTemplates"] diff --git a/libs/langchain/langchain/prompts/loading.py b/libs/langchain/langchain/prompts/loading.py index d953cc74680..df0f62f8503 100644 --- a/libs/langchain/langchain/prompts/loading.py +++ b/libs/langchain/langchain/prompts/loading.py @@ -1,163 +1,4 @@ -"""Load prompts.""" -import json -import logging -from pathlib import Path -from typing import Callable, Dict, Union +from langchain_core.prompts.loading import load_prompt, load_prompt_from_config +from langchain_core.utils.loading import try_load_from_hub -import yaml - -from langchain.prompts.few_shot import FewShotPromptTemplate -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, StrOutputParser -from langchain.utils.loading import try_load_from_hub - -URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" -logger = logging.getLogger(__name__) - - -def load_prompt_from_config(config: dict) -> BasePromptTemplate: - """Load prompt from Config Dict.""" - if "_type" not in config: - logger.warning("No `_type` key found, defaulting to `prompt`.") - config_type = config.pop("_type", "prompt") - - if config_type not in type_to_loader_dict: - raise ValueError(f"Loading {config_type} prompt not supported") - - prompt_loader = type_to_loader_dict[config_type] - return prompt_loader(config) - - -def _load_template(var_name: str, config: dict) -> dict: - """Load template from the path if applicable.""" - # Check if template_path exists in config. - if f"{var_name}_path" in config: - # If it does, make sure template variable doesn't also exist. - if var_name in config: - raise ValueError( - f"Both `{var_name}_path` and `{var_name}` cannot be provided." - ) - # Pop the template path from the config. - template_path = Path(config.pop(f"{var_name}_path")) - # Load the template. - if template_path.suffix == ".txt": - with open(template_path) as f: - template = f.read() - else: - raise ValueError - # Set the template variable to the extracted variable. 
- config[var_name] = template - return config - - -def _load_examples(config: dict) -> dict: - """Load examples if necessary.""" - if isinstance(config["examples"], list): - pass - elif isinstance(config["examples"], str): - with open(config["examples"]) as f: - if config["examples"].endswith(".json"): - examples = json.load(f) - elif config["examples"].endswith((".yaml", ".yml")): - examples = yaml.safe_load(f) - else: - raise ValueError( - "Invalid file format. Only json or yaml formats are supported." - ) - config["examples"] = examples - else: - raise ValueError("Invalid examples format. Only list or string are supported.") - return config - - -def _load_output_parser(config: dict) -> dict: - """Load output parser.""" - if "output_parser" in config and config["output_parser"]: - _config = config.pop("output_parser") - output_parser_type = _config.pop("_type") - if output_parser_type == "regex_parser": - from langchain.output_parsers.regex import RegexParser - - output_parser: BaseLLMOutputParser = RegexParser(**_config) - elif output_parser_type == "default": - output_parser = StrOutputParser(**_config) - else: - raise ValueError(f"Unsupported output parser {output_parser_type}") - config["output_parser"] = output_parser - return config - - -def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate: - """Load the "few shot" prompt from the config.""" - # Load the suffix and prefix templates. - config = _load_template("suffix", config) - config = _load_template("prefix", config) - # Load the example prompt. - if "example_prompt_path" in config: - if "example_prompt" in config: - raise ValueError( - "Only one of example_prompt and example_prompt_path should " - "be specified." - ) - config["example_prompt"] = load_prompt(config.pop("example_prompt_path")) - else: - config["example_prompt"] = load_prompt_from_config(config["example_prompt"]) - # Load the examples. - config = _load_examples(config) - config = _load_output_parser(config) - return FewShotPromptTemplate(**config) - - -def _load_prompt(config: dict) -> PromptTemplate: - """Load the prompt template from config.""" - # Load the template from disk if necessary. - config = _load_template("template", config) - config = _load_output_parser(config) - - template_format = config.get("template_format", "f-string") - if template_format == "jinja2": - # Disabled due to: - # https://github.com/langchain-ai/langchain/issues/4394 - raise ValueError( - f"Loading templates with '{template_format}' format is no longer supported " - f"since it can lead to arbitrary code execution. Please migrate to using " - f"the 'f-string' template format, which does not suffer from this issue." - ) - - return PromptTemplate(**config) - - -def load_prompt(path: Union[str, Path]) -> BasePromptTemplate: - """Unified method for loading a prompt from LangChainHub or local fs.""" - if hub_result := try_load_from_hub( - path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"} - ): - return hub_result - else: - return _load_prompt_from_file(path) - - -def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: - """Load prompt from file.""" - # Convert file to a Path object. - if isinstance(file, str): - file_path = Path(file) - else: - file_path = file - # Load from either json or yaml. 
- if file_path.suffix == ".json": - with open(file_path) as f: - config = json.load(f) - elif file_path.suffix == ".yaml": - with open(file_path, "r") as f: - config = yaml.safe_load(f) - else: - raise ValueError(f"Got unsupported file type {file_path.suffix}") - # Load the prompt from the config now. - return load_prompt_from_config(config) - - -type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = { - "prompt": _load_prompt, - "few_shot": _load_few_shot_prompt, -} +__all__ = ["load_prompt_from_config", "load_prompt", "try_load_from_hub"] diff --git a/libs/langchain/langchain/prompts/pipeline.py b/libs/langchain/langchain/prompts/pipeline.py index 4183b97387c..88e73e16f33 100644 --- a/libs/langchain/langchain/prompts/pipeline.py +++ b/libs/langchain/langchain/prompts/pipeline.py @@ -1,56 +1,3 @@ -from typing import Any, Dict, List, Tuple +from langchain_core.prompts.pipeline import PipelinePromptTemplate -from langchain.prompts.chat import BaseChatPromptTemplate -from langchain.pydantic_v1 import root_validator -from langchain.schema import BasePromptTemplate, PromptValue - - -def _get_inputs(inputs: dict, input_variables: List[str]) -> dict: - return {k: inputs[k] for k in input_variables} - - -class PipelinePromptTemplate(BasePromptTemplate): - """A prompt template for composing multiple prompt templates together. - - This can be useful when you want to reuse parts of prompts. - A PipelinePrompt consists of two main parts: - - final_prompt: This is the final prompt that is returned - - pipeline_prompts: This is a list of tuples, consisting - of a string (`name`) and a Prompt Template. - Each PromptTemplate will be formatted and then passed - to future prompt templates as a variable with - the same name as `name` - """ - - final_prompt: BasePromptTemplate - """The final prompt that is returned.""" - pipeline_prompts: List[Tuple[str, BasePromptTemplate]] - """A list of tuples, consisting of a string (`name`) and a Prompt Template.""" - - @root_validator(pre=True) - def get_input_variables(cls, values: Dict) -> Dict: - """Get input variables.""" - created_variables = set() - all_variables = set() - for k, prompt in values["pipeline_prompts"]: - created_variables.add(k) - all_variables.update(prompt.input_variables) - values["input_variables"] = list(all_variables.difference(created_variables)) - return values - - def format_prompt(self, **kwargs: Any) -> PromptValue: - for k, prompt in self.pipeline_prompts: - _inputs = _get_inputs(kwargs, prompt.input_variables) - if isinstance(prompt, BaseChatPromptTemplate): - kwargs[k] = prompt.format_messages(**_inputs) - else: - kwargs[k] = prompt.format(**_inputs) - _inputs = _get_inputs(kwargs, self.final_prompt.input_variables) - return self.final_prompt.format_prompt(**_inputs) - - def format(self, **kwargs: Any) -> str: - return self.format_prompt(**kwargs).to_string() - - @property - def _prompt_type(self) -> str: - raise ValueError +__all__ = ["PipelinePromptTemplate"] diff --git a/libs/langchain/langchain/prompts/prompt.py b/libs/langchain/langchain/prompts/prompt.py index e935d15aae3..047d55adfed 100644 --- a/libs/langchain/langchain/prompts/prompt.py +++ b/libs/langchain/langchain/prompts/prompt.py @@ -1,250 +1,3 @@ -"""Prompt schema definition.""" -from __future__ import annotations +from langchain_core.prompts.prompt import PromptTemplate -from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Union - -from langchain.prompts.base import ( - DEFAULT_FORMATTER_MAPPING, - StringPromptTemplate, - 
check_valid_template, - get_template_variables, -) -from langchain.pydantic_v1 import root_validator - - -class PromptTemplate(StringPromptTemplate): - """A prompt template for a language model. - - A prompt template consists of a string template. It accepts a set of parameters - from the user that can be used to generate a prompt for a language model. - - The template can be formatted using either f-strings (default) or jinja2 syntax. - - *Security warning*: Prefer using `template_format="f-string"` instead of - `template_format="jinja2"`, or make sure to NEVER accept jinja2 templates - from untrusted sources as they may lead to arbitrary Python code execution. - - As of LangChain 0.0.329, Jinja2 templates will be rendered using - Jinja2's SandboxedEnvironment by default. This sand-boxing should - be treated as a best-effort approach rather than a guarantee of security, - as it is an opt-out rather than opt-in approach. - - Despite the sand-boxing, we recommend to never use jinja2 templates - from untrusted sources. - - Example: - - .. code-block:: python - - from langchain.prompts import PromptTemplate - - # Instantiation using from_template (recommended) - prompt = PromptTemplate.from_template("Say {foo}") - prompt.format(foo="bar") - - # Instantiation using initializer - prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}") - """ - - @property - def lc_attributes(self) -> Dict[str, Any]: - return { - "template_format": self.template_format, - } - - input_variables: List[str] - """A list of the names of the variables the prompt template expects.""" - - template: str - """The prompt template.""" - - template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string" - """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" - - validate_template: bool = False - """Whether or not to try validating the template.""" - - def __add__(self, other: Any) -> PromptTemplate: - """Override the + operator to allow for combining prompt templates.""" - # Allow for easy combining - if isinstance(other, PromptTemplate): - if self.template_format != "f-string": - raise ValueError( - "Adding prompt templates only supported for f-strings." - ) - if other.template_format != "f-string": - raise ValueError( - "Adding prompt templates only supported for f-strings." - ) - input_variables = list( - set(self.input_variables) | set(other.input_variables) - ) - template = self.template + other.template - # If any do not want to validate, then don't - validate_template = self.validate_template and other.validate_template - partial_variables = {k: v for k, v in self.partial_variables.items()} - for k, v in other.partial_variables.items(): - if k in partial_variables: - raise ValueError("Cannot have same variable partialed twice.") - else: - partial_variables[k] = v - return PromptTemplate( - template=template, - input_variables=input_variables, - partial_variables=partial_variables, - template_format="f-string", - validate_template=validate_template, - ) - elif isinstance(other, str): - prompt = PromptTemplate.from_template(other) - return self + prompt - else: - raise NotImplementedError(f"Unsupported operand type for +: {type(other)}") - - @property - def _prompt_type(self) -> str: - """Return the prompt type key.""" - return "prompt" - - def format(self, **kwargs: Any) -> str: - """Format the prompt with the inputs. - - Args: - kwargs: Any arguments to be passed to the prompt template. - - Returns: - A formatted string. - - Example: - - .. 
code-block:: python - - prompt.format(variable1="foo") - """ - kwargs = self._merge_partial_and_user_variables(**kwargs) - return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) - - @root_validator() - def template_is_valid(cls, values: Dict) -> Dict: - """Check that template and input variables are consistent.""" - if values["validate_template"]: - all_inputs = values["input_variables"] + list(values["partial_variables"]) - check_valid_template( - values["template"], values["template_format"], all_inputs - ) - elif values.get("template_format"): - values["input_variables"] = [ - var - for var in get_template_variables( - values["template"], values["template_format"] - ) - if var not in values["partial_variables"] - ] - return values - - @classmethod - def from_examples( - cls, - examples: List[str], - suffix: str, - input_variables: List[str], - example_separator: str = "\n\n", - prefix: str = "", - **kwargs: Any, - ) -> PromptTemplate: - """Take examples in list format with prefix and suffix to create a prompt. - - Intended to be used as a way to dynamically create a prompt from examples. - - Args: - examples: List of examples to use in the prompt. - suffix: String to go after the list of examples. Should generally - set up the user's input. - input_variables: A list of variable names the final prompt template - will expect. - example_separator: The separator to use in between examples. Defaults - to two new line characters. - prefix: String that should go before any examples. Generally includes - examples. Default to an empty string. - - Returns: - The final prompt generated. - """ - template = example_separator.join([prefix, *examples, suffix]) - return cls(input_variables=input_variables, template=template, **kwargs) - - @classmethod - def from_file( - cls, template_file: Union[str, Path], input_variables: List[str], **kwargs: Any - ) -> PromptTemplate: - """Load a prompt from a file. - - Args: - template_file: The path to the file containing the prompt template. - input_variables: A list of variable names the final prompt template - will expect. - - Returns: - The prompt loaded from the file. - """ - with open(str(template_file), "r") as f: - template = f.read() - return cls(input_variables=input_variables, template=template, **kwargs) - - @classmethod - def from_template( - cls, - template: str, - *, - template_format: str = "f-string", - partial_variables: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> PromptTemplate: - """Load a prompt template from a template. - - *Security warning*: Prefer using `template_format="f-string"` instead of - `template_format="jinja2"`, or make sure to NEVER accept jinja2 templates - from untrusted sources as they may lead to arbitrary Python code execution. - - As of LangChain 0.0.329, Jinja2 templates will be rendered using - Jinja2's SandboxedEnvironment by default. This sand-boxing should - be treated as a best-effort approach rather than a guarantee of security, - as it is an opt-out rather than opt-in approach. - - Despite the sand-boxing, we recommend to never use jinja2 templates - from untrusted sources. - - Args: - template: The template to load. - template_format: The format of the template. Use `jinja2` for jinja2, - and `f-string` or None for f-strings. - partial_variables: A dictionary of variables that can be used to partially - fill in the template. 
For example, if the template is - `"{variable1} {variable2}"`, and `partial_variables` is - `{"variable1": "foo"}`, then the final prompt will be - `"foo {variable2}"`. - - Returns: - The prompt template loaded from the template. - """ - - input_variables = get_template_variables(template, template_format) - _partial_variables = partial_variables or {} - - if _partial_variables: - input_variables = [ - var for var in input_variables if var not in _partial_variables - ] - - return cls( - input_variables=input_variables, - template=template, - template_format=template_format, - partial_variables=_partial_variables, - **kwargs, - ) - - -# For backwards compatibility. -Prompt = PromptTemplate +__all__ = ["PromptTemplate"] diff --git a/libs/langchain/langchain/retrievers/arcee.py b/libs/langchain/langchain/retrievers/arcee.py index 6bfba7eef9a..a3bd2883e34 100644 --- a/libs/langchain/langchain/retrievers/arcee.py +++ b/libs/langchain/langchain/retrievers/arcee.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema import BaseRetriever + from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.docstore.document import Document -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema import BaseRetriever from langchain.utilities.arcee import ArceeWrapper, DALMFilter from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/retrievers/arxiv.py b/libs/langchain/langchain/retrievers/arxiv.py index b14f0b61f8a..50880935621 100644 --- a/libs/langchain/langchain/retrievers/arxiv.py +++ b/libs/langchain/langchain/retrievers/arxiv.py @@ -1,7 +1,8 @@ from typing import List +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.schema import BaseRetriever, Document from langchain.utilities.arxiv import ArxivAPIWrapper diff --git a/libs/langchain/langchain/retrievers/azure_cognitive_search.py b/libs/langchain/langchain/retrievers/azure_cognitive_search.py index e95e8bd4929..4a75291a66b 100644 --- a/libs/langchain/langchain/retrievers/azure_cognitive_search.py +++ b/libs/langchain/langchain/retrievers/azure_cognitive_search.py @@ -5,13 +5,13 @@ from typing import Dict, List, Optional import aiohttp import requests +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema import BaseRetriever, Document from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema import BaseRetriever, Document from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/retrievers/bm25.py b/libs/langchain/langchain/retrievers/bm25.py index 2a03b8b2df6..f1868e9b0e7 100644 --- a/libs/langchain/langchain/retrievers/bm25.py +++ b/libs/langchain/langchain/retrievers/bm25.py @@ -2,8 +2,9 @@ from __future__ import annotations from typing import Any, Callable, Dict, Iterable, List, Optional +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.schema import BaseRetriever, Document def default_preprocessing_func(text: str) -> List[str]: diff --git a/libs/langchain/langchain/retrievers/chaindesk.py b/libs/langchain/langchain/retrievers/chaindesk.py index f4f85802f12..71bdae23286 100644 
--- a/libs/langchain/langchain/retrievers/chaindesk.py +++ b/libs/langchain/langchain/retrievers/chaindesk.py @@ -2,12 +2,12 @@ from typing import Any, List, Optional import aiohttp import requests +from langchain_core.schema import BaseRetriever, Document from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain.schema import BaseRetriever, Document class ChaindeskRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/chatgpt_plugin_retriever.py b/libs/langchain/langchain/retrievers/chatgpt_plugin_retriever.py index e279467609b..e38c9f85da7 100644 --- a/libs/langchain/langchain/retrievers/chatgpt_plugin_retriever.py +++ b/libs/langchain/langchain/retrievers/chatgpt_plugin_retriever.py @@ -4,12 +4,12 @@ from typing import List, Optional import aiohttp import requests +from langchain_core.schema import BaseRetriever, Document from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain.schema import BaseRetriever, Document class ChatGPTPluginRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/cohere_rag_retriever.py b/libs/langchain/langchain/retrievers/cohere_rag_retriever.py index 9d79adee69f..7c1fa4ac58e 100644 --- a/libs/langchain/langchain/retrievers/cohere_rag_retriever.py +++ b/libs/langchain/langchain/retrievers/cohere_rag_retriever.py @@ -2,16 +2,17 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict, List +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BaseRetriever, Document, HumanMessage + from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) from langchain.chat_models.base import BaseChatModel -from langchain.pydantic_v1 import Field -from langchain.schema import BaseRetriever, Document, HumanMessage if TYPE_CHECKING: - from langchain.schema.messages import BaseMessage + from langchain_core.schema.messages import BaseMessage def _get_docs(response: Any) -> List[Document]: diff --git a/libs/langchain/langchain/retrievers/contextual_compression.py b/libs/langchain/langchain/retrievers/contextual_compression.py index 0a5654b052b..1ef429ecddb 100644 --- a/libs/langchain/langchain/retrievers/contextual_compression.py +++ b/libs/langchain/langchain/retrievers/contextual_compression.py @@ -1,5 +1,7 @@ from typing import Any, List +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, @@ -7,7 +9,6 @@ from langchain.callbacks.manager import ( from langchain.retrievers.document_compressors.base import ( BaseDocumentCompressor, ) -from langchain.schema import BaseRetriever, Document class ContextualCompressionRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/databerry.py b/libs/langchain/langchain/retrievers/databerry.py index 4113f99c2fd..6913a642891 100644 --- a/libs/langchain/langchain/retrievers/databerry.py +++ b/libs/langchain/langchain/retrievers/databerry.py @@ -2,12 +2,12 @@ from typing import List, Optional import aiohttp import requests +from langchain_core.schema import BaseRetriever, Document from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain.schema import BaseRetriever, Document class DataberryRetriever(BaseRetriever): diff --git 
a/libs/langchain/langchain/retrievers/docarray.py b/libs/langchain/langchain/retrievers/docarray.py index 02c4f9fb3a8..8c20498a14c 100644 --- a/libs/langchain/langchain/retrievers/docarray.py +++ b/libs/langchain/langchain/retrievers/docarray.py @@ -2,10 +2,10 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union import numpy as np +from langchain_core.schema import BaseRetriever, Document +from langchain_core.schema.embeddings import Embeddings from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.schema import BaseRetriever, Document -from langchain.schema.embeddings import Embeddings from langchain.vectorstores.utils import maximal_marginal_relevance diff --git a/libs/langchain/langchain/retrievers/document_compressors/base.py b/libs/langchain/langchain/retrievers/document_compressors/base.py index a468f097013..42874801965 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/base.py +++ b/libs/langchain/langchain/retrievers/document_compressors/base.py @@ -3,9 +3,10 @@ from abc import ABC, abstractmethod from inspect import signature from typing import List, Optional, Sequence, Union +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema import BaseDocumentTransformer, Document + from langchain.callbacks.manager import Callbacks -from langchain.pydantic_v1 import BaseModel -from langchain.schema import BaseDocumentTransformer, Document class BaseDocumentCompressor(BaseModel, ABC): diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py index 7fc00416408..c6ff92f1471 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py @@ -4,15 +4,16 @@ from __future__ import annotations import asyncio from typing import Any, Callable, Dict, Optional, Sequence +from langchain_core.prompts import PromptTemplate +from langchain_core.schema import BaseOutputParser, Document +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import Callbacks from langchain.chains.llm import LLMChain -from langchain.prompts import PromptTemplate from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.retrievers.document_compressors.chain_extract_prompt import ( prompt_template, ) -from langchain.schema import BaseOutputParser, Document -from langchain.schema.language_model import BaseLanguageModel def default_get_input(query: str, doc: Document) -> Dict[str, Any]: diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py index 538f683fa61..5ce340480e0 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py @@ -1,16 +1,17 @@ """Filter that uses an LLM to drop documents that aren't relevant to the query.""" from typing import Any, Callable, Dict, Optional, Sequence +from langchain_core.prompts import PromptTemplate +from langchain_core.schema import BasePromptTemplate, Document +from langchain_core.schema.language_model import BaseLanguageModel + from langchain.callbacks.manager import Callbacks from langchain.chains import LLMChain from langchain.output_parsers.boolean import BooleanOutputParser -from 
langchain.prompts import PromptTemplate from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.retrievers.document_compressors.chain_filter_prompt import ( prompt_template, ) -from langchain.schema import BasePromptTemplate, Document -from langchain.schema.language_model import BaseLanguageModel def _get_default_chain_prompt() -> PromptTemplate: diff --git a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py index d8790b32390..31383253438 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py @@ -2,10 +2,11 @@ from __future__ import annotations from typing import TYPE_CHECKING, Dict, Optional, Sequence +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema import Document + from langchain.callbacks.manager import Callbacks -from langchain.pydantic_v1 import Extra, root_validator from langchain.retrievers.document_compressors.base import BaseDocumentCompressor -from langchain.schema import Document from langchain.utils import get_from_dict_or_env if TYPE_CHECKING: diff --git a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py index 9241c3bc595..5b02bd48bb6 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py @@ -1,18 +1,18 @@ from typing import Callable, Dict, Optional, Sequence import numpy as np +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import Document +from langchain_core.schema.embeddings import Embeddings from langchain.callbacks.manager import Callbacks from langchain.document_transformers.embeddings_redundant_filter import ( _get_embeddings_from_stateful_docs, get_stateful_documents, ) -from langchain.pydantic_v1 import root_validator from langchain.retrievers.document_compressors.base import ( BaseDocumentCompressor, ) -from langchain.schema import Document -from langchain.schema.embeddings import Embeddings from langchain.utils.math import cosine_similarity diff --git a/libs/langchain/langchain/retrievers/elastic_search_bm25.py b/libs/langchain/langchain/retrievers/elastic_search_bm25.py index 68fb4323f25..cb69b8e3dd0 100644 --- a/libs/langchain/langchain/retrievers/elastic_search_bm25.py +++ b/libs/langchain/langchain/retrievers/elastic_search_bm25.py @@ -5,9 +5,10 @@ from __future__ import annotations import uuid from typing import Any, Iterable, List +from langchain_core.schema import BaseRetriever + from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.docstore.document import Document -from langchain.schema import BaseRetriever class ElasticSearchBM25Retriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py index 93879e4d3cd..df475d1dcaf 100644 --- a/libs/langchain/langchain/retrievers/ensemble.py +++ b/libs/langchain/langchain/retrievers/ensemble.py @@ -4,12 +4,13 @@ multiple retrievers by using weighted Reciprocal Rank Fusion """ from typing import Any, Dict, List +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import 
( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseRetriever, Document class EnsembleRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py b/libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py index a99f35264f2..e7f5e65cbbd 100644 --- a/libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py +++ b/libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py @@ -1,10 +1,11 @@ """Retriever wrapper for Google Cloud Document AI Warehouse.""" from typing import TYPE_CHECKING, Any, Dict, List, Optional +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import BaseRetriever + from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.docstore.document import Document -from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseRetriever from langchain.utilities.vertexai import get_client_info from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/retrievers/google_vertex_ai_search.py b/libs/langchain/langchain/retrievers/google_vertex_ai_search.py index 0e0c165c53e..786920b6625 100644 --- a/libs/langchain/langchain/retrievers/google_vertex_ai_search.py +++ b/libs/langchain/langchain/retrievers/google_vertex_ai_search.py @@ -3,9 +3,10 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.schema import BaseRetriever, Document from langchain.utilities.vertexai import get_client_info from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/retrievers/kay.py b/libs/langchain/langchain/retrievers/kay.py index 47e0471c156..e3aac4dee3a 100644 --- a/libs/langchain/langchain/retrievers/kay.py +++ b/libs/langchain/langchain/retrievers/kay.py @@ -2,8 +2,9 @@ from __future__ import annotations from typing import Any, List +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.schema import BaseRetriever, Document class KayAiRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/kendra.py b/libs/langchain/langchain/retrievers/kendra.py index 55382993b3d..7667fc7c316 100644 --- a/libs/langchain/langchain/retrievers/kendra.py +++ b/libs/langchain/langchain/retrievers/kendra.py @@ -2,10 +2,11 @@ import re from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator, validator +from langchain_core.schema import BaseRetriever + from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.docstore.document import Document -from langchain.pydantic_v1 import BaseModel, Extra, root_validator, validator -from langchain.schema import BaseRetriever def clean_excerpt(excerpt: str) -> str: diff --git a/libs/langchain/langchain/retrievers/knn.py b/libs/langchain/langchain/retrievers/knn.py index e8013972e1d..e5ef90907ec 100644 --- 
a/libs/langchain/langchain/retrievers/knn.py +++ b/libs/langchain/langchain/retrievers/knn.py @@ -8,10 +8,10 @@ import concurrent.futures from typing import Any, List, Optional import numpy as np +from langchain_core.schema import BaseRetriever, Document +from langchain_core.schema.embeddings import Embeddings from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.schema import BaseRetriever, Document -from langchain.schema.embeddings import Embeddings def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray: diff --git a/libs/langchain/langchain/retrievers/llama_index.py b/libs/langchain/langchain/retrievers/llama_index.py index 81f30d2104c..e602bf72fc9 100644 --- a/libs/langchain/langchain/retrievers/llama_index.py +++ b/libs/langchain/langchain/retrievers/llama_index.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List, cast +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.pydantic_v1 import Field -from langchain.schema import BaseRetriever, Document class LlamaIndexRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/merger_retriever.py b/libs/langchain/langchain/retrievers/merger_retriever.py index 962fb60225e..6f12050042f 100644 --- a/libs/langchain/langchain/retrievers/merger_retriever.py +++ b/libs/langchain/langchain/retrievers/merger_retriever.py @@ -1,11 +1,12 @@ import asyncio from typing import List +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain.schema import BaseRetriever, Document class MergerRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/metal.py b/libs/langchain/langchain/retrievers/metal.py index aed77b45137..f3039ddffc6 100644 --- a/libs/langchain/langchain/retrievers/metal.py +++ b/libs/langchain/langchain/retrievers/metal.py @@ -1,8 +1,9 @@ from typing import Any, List, Optional +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseRetriever, Document class MetalRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/milvus.py b/libs/langchain/langchain/retrievers/milvus.py index eadc7d81236..121474aae7b 100644 --- a/libs/langchain/langchain/retrievers/milvus.py +++ b/libs/langchain/langchain/retrievers/milvus.py @@ -2,10 +2,11 @@ import warnings from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import BaseRetriever, Document +from langchain_core.schema.embeddings import Embeddings + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseRetriever, Document -from langchain.schema.embeddings import Embeddings from langchain.vectorstores.milvus import Milvus # TODO: Update to MilvusClient + Hybrid Search when available diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py index 3ac2beaa445..652ae499494 100644 --- a/libs/langchain/langchain/retrievers/multi_query.py +++ b/libs/langchain/langchain/retrievers/multi_query.py 
@@ -2,6 +2,10 @@ import asyncio import logging from typing import List, Sequence +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, @@ -9,9 +13,6 @@ from langchain.callbacks.manager import ( from langchain.chains.llm import LLMChain from langchain.llms.base import BaseLLM from langchain.output_parsers.pydantic import PydanticOutputParser -from langchain.prompts.prompt import PromptTemplate -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema import BaseRetriever, Document logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/retrievers/multi_vector.py b/libs/langchain/langchain/retrievers/multi_vector.py index 627f6e3f661..fec3b950cc9 100644 --- a/libs/langchain/langchain/retrievers/multi_vector.py +++ b/libs/langchain/langchain/retrievers/multi_vector.py @@ -1,9 +1,10 @@ from typing import List +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BaseRetriever, BaseStore, Document +from langchain_core.schema.vectorstore import VectorStore + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.pydantic_v1 import Field -from langchain.schema import BaseRetriever, BaseStore, Document -from langchain.schema.vectorstore import VectorStore class MultiVectorRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/parent_document_retriever.py b/libs/langchain/langchain/retrievers/parent_document_retriever.py index dd5aa720675..b0d3d3ce20d 100644 --- a/libs/langchain/langchain/retrievers/parent_document_retriever.py +++ b/libs/langchain/langchain/retrievers/parent_document_retriever.py @@ -1,8 +1,9 @@ import uuid from typing import List, Optional +from langchain_core.schema.document import Document + from langchain.retrievers import MultiVectorRetriever -from langchain.schema.document import Document from langchain.text_splitter import TextSplitter diff --git a/libs/langchain/langchain/retrievers/pinecone_hybrid_search.py b/libs/langchain/langchain/retrievers/pinecone_hybrid_search.py index 98563ed8a63..cfe73ed4310 100644 --- a/libs/langchain/langchain/retrievers/pinecone_hybrid_search.py +++ b/libs/langchain/langchain/retrievers/pinecone_hybrid_search.py @@ -3,10 +3,11 @@ import hashlib from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.schema import BaseRetriever, Document +from langchain_core.schema.embeddings import Embeddings + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema import BaseRetriever, Document -from langchain.schema.embeddings import Embeddings def hash_text(text: str) -> str: diff --git a/libs/langchain/langchain/retrievers/pubmed.py b/libs/langchain/langchain/retrievers/pubmed.py index b87441e170c..a3600d3b9fd 100644 --- a/libs/langchain/langchain/retrievers/pubmed.py +++ b/libs/langchain/langchain/retrievers/pubmed.py @@ -1,7 +1,8 @@ from typing import List +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.schema import BaseRetriever, Document from langchain.utilities.pubmed import PubMedAPIWrapper diff --git 
a/libs/langchain/langchain/retrievers/re_phraser.py b/libs/langchain/langchain/retrievers/re_phraser.py index eba5910d430..350866af4f5 100644 --- a/libs/langchain/langchain/retrievers/re_phraser.py +++ b/libs/langchain/langchain/retrievers/re_phraser.py @@ -1,14 +1,15 @@ import logging from typing import List +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) from langchain.chains.llm import LLMChain from langchain.llms.base import BaseLLM -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import BaseRetriever, Document logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/retrievers/remote_retriever.py b/libs/langchain/langchain/retrievers/remote_retriever.py index 3df22253e7d..a5ebd8ef63c 100644 --- a/libs/langchain/langchain/retrievers/remote_retriever.py +++ b/libs/langchain/langchain/retrievers/remote_retriever.py @@ -2,12 +2,12 @@ from typing import List, Optional import aiohttp import requests +from langchain_core.schema import BaseRetriever, Document from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain.schema import BaseRetriever, Document class RemoteLangChainRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index 2390351b52b..7739dfbe72c 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -2,6 +2,12 @@ import logging from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.runnables import Runnable +from langchain_core.schema import BaseRetriever, Document +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.vectorstore import VectorStore + from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, @@ -9,7 +15,6 @@ from langchain.callbacks.manager import ( from langchain.chains.query_constructor.base import load_query_constructor_runnable from langchain.chains.query_constructor.ir import StructuredQuery, Visitor from langchain.chains.query_constructor.schema import AttributeInfo -from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.retrievers.self_query.chroma import ChromaTranslator from langchain.retrievers.self_query.dashvector import DashvectorTranslator from langchain.retrievers.self_query.deeplake import DeepLakeTranslator @@ -24,10 +29,6 @@ from langchain.retrievers.self_query.supabase import SupabaseVectorTranslator from langchain.retrievers.self_query.timescalevector import TimescaleVectorTranslator from langchain.retrievers.self_query.vectara import VectaraTranslator from langchain.retrievers.self_query.weaviate import WeaviateTranslator -from langchain.schema import BaseRetriever, Document -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.runnable import Runnable -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores import ( Chroma, DashVector, diff --git a/libs/langchain/langchain/retrievers/svm.py b/libs/langchain/langchain/retrievers/svm.py index f7faabfca76..99cff11fc18 100644 --- 
a/libs/langchain/langchain/retrievers/svm.py +++ b/libs/langchain/langchain/retrievers/svm.py @@ -4,10 +4,10 @@ import concurrent.futures from typing import Any, Iterable, List, Optional import numpy as np +from langchain_core.schema import BaseRetriever, Document +from langchain_core.schema.embeddings import Embeddings from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.schema import BaseRetriever, Document -from langchain.schema.embeddings import Embeddings def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray: diff --git a/libs/langchain/langchain/retrievers/tavily_search_api.py b/libs/langchain/langchain/retrievers/tavily_search_api.py index 48f95abb8b0..5fdbc28eeae 100644 --- a/libs/langchain/langchain/retrievers/tavily_search_api.py +++ b/libs/langchain/langchain/retrievers/tavily_search_api.py @@ -2,9 +2,10 @@ import os from enum import Enum from typing import Any, Dict, List, Optional +from langchain_core.schema import Document +from langchain_core.schema.retriever import BaseRetriever + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.schema import Document -from langchain.schema.retriever import BaseRetriever class SearchDepth(Enum): diff --git a/libs/langchain/langchain/retrievers/tfidf.py b/libs/langchain/langchain/retrievers/tfidf.py index fbc4e387926..f81bc25f53a 100644 --- a/libs/langchain/langchain/retrievers/tfidf.py +++ b/libs/langchain/langchain/retrievers/tfidf.py @@ -4,8 +4,9 @@ import pickle from pathlib import Path from typing import Any, Dict, Iterable, List, Optional +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.schema import BaseRetriever, Document class TFIDFRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/time_weighted_retriever.py b/libs/langchain/langchain/retrievers/time_weighted_retriever.py index de94805a535..e3a901c9b6f 100644 --- a/libs/langchain/langchain/retrievers/time_weighted_retriever.py +++ b/libs/langchain/langchain/retrievers/time_weighted_retriever.py @@ -2,10 +2,11 @@ import datetime from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import BaseRetriever, Document +from langchain_core.schema.vectorstore import VectorStore + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.pydantic_v1 import Field -from langchain.schema import BaseRetriever, Document -from langchain.schema.vectorstore import VectorStore def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float: diff --git a/libs/langchain/langchain/retrievers/vespa_retriever.py b/libs/langchain/langchain/retrievers/vespa_retriever.py index 54f64628a50..17b7e07e6bd 100644 --- a/libs/langchain/langchain/retrievers/vespa_retriever.py +++ b/libs/langchain/langchain/retrievers/vespa_retriever.py @@ -3,8 +3,9 @@ from __future__ import annotations import json from typing import Any, Dict, List, Literal, Optional, Sequence, Union +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.schema import BaseRetriever, Document class VespaRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/weaviate_hybrid_search.py b/libs/langchain/langchain/retrievers/weaviate_hybrid_search.py index e0a366406fd..b28a643e8f7 100644 
--- a/libs/langchain/langchain/retrievers/weaviate_hybrid_search.py +++ b/libs/langchain/langchain/retrievers/weaviate_hybrid_search.py @@ -3,10 +3,11 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, cast from uuid import uuid4 +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import BaseRetriever + from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.docstore.document import Document -from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseRetriever class WeaviateHybridSearchRetriever(BaseRetriever): diff --git a/libs/langchain/langchain/retrievers/web_research.py b/libs/langchain/langchain/retrievers/web_research.py index 73b822dd965..f1ebe46ef3f 100644 --- a/libs/langchain/langchain/retrievers/web_research.py +++ b/libs/langchain/langchain/retrievers/web_research.py @@ -2,6 +2,11 @@ import logging import re from typing import List, Optional +from langchain_core.prompts import BasePromptTemplate, PromptTemplate +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import BaseRetriever, Document +from langchain_core.schema.vectorstore import VectorStore + from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, @@ -13,10 +18,6 @@ from langchain.document_transformers import Html2TextTransformer from langchain.llms import LlamaCpp from langchain.llms.base import BaseLLM from langchain.output_parsers.pydantic import PydanticOutputParser -from langchain.prompts import BasePromptTemplate, PromptTemplate -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema import BaseRetriever, Document -from langchain.schema.vectorstore import VectorStore from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter from langchain.utilities import GoogleSearchAPIWrapper diff --git a/libs/langchain/langchain/retrievers/wikipedia.py b/libs/langchain/langchain/retrievers/wikipedia.py index 7b6b8c3f052..fccbfa1508f 100644 --- a/libs/langchain/langchain/retrievers/wikipedia.py +++ b/libs/langchain/langchain/retrievers/wikipedia.py @@ -1,7 +1,8 @@ from typing import List +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.schema import BaseRetriever, Document from langchain.utilities.wikipedia import WikipediaAPIWrapper diff --git a/libs/langchain/langchain/retrievers/you.py b/libs/langchain/langchain/retrievers/you.py index 2cc0476858d..e0bb437e0f6 100644 --- a/libs/langchain/langchain/retrievers/you.py +++ b/libs/langchain/langchain/retrievers/you.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseRetriever, Document from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/retrievers/zep.py b/libs/langchain/langchain/retrievers/zep.py index 8253a35082d..88a6a5f5037 100644 --- a/libs/langchain/langchain/retrievers/zep.py +++ b/libs/langchain/langchain/retrievers/zep.py @@ -3,12 +3,13 @@ from __future__ import annotations from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Optional +from langchain_core.pydantic_v1 import root_validator +from 
langchain_core.schema import BaseRetriever, Document + from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseRetriever, Document if TYPE_CHECKING: from zep_python.memory import MemorySearchResult diff --git a/libs/langchain/langchain/retrievers/zilliz.py b/libs/langchain/langchain/retrievers/zilliz.py index 5527b0faf6d..f144ca4cc2b 100644 --- a/libs/langchain/langchain/retrievers/zilliz.py +++ b/libs/langchain/langchain/retrievers/zilliz.py @@ -1,10 +1,11 @@ import warnings from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import root_validator +from langchain_core.schema import BaseRetriever, Document +from langchain_core.schema.embeddings import Embeddings + from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseRetriever, Document -from langchain.schema.embeddings import Embeddings from langchain.vectorstores.zilliz import Zilliz # TODO: Update to ZillizClient + Hybrid Search when available diff --git a/libs/langchain/langchain/runnables/hub.py b/libs/langchain/langchain/runnables/hub.py index 64dbe2f6180..71fad87ba67 100644 --- a/libs/langchain/langchain/runnables/hub.py +++ b/libs/langchain/langchain/runnables/hub.py @@ -1,6 +1,6 @@ from typing import Any, Optional -from langchain.schema.runnable.base import Input, Output, RunnableBindingBase +from langchain_core.runnables.base import Input, Output, RunnableBindingBase class HubRunnable(RunnableBindingBase[Input, Output]): diff --git a/libs/langchain/langchain/runnables/openai_functions.py b/libs/langchain/langchain/runnables/openai_functions.py index cdabef48fc0..f03d76b6495 100644 --- a/libs/langchain/langchain/runnables/openai_functions.py +++ b/libs/langchain/langchain/runnables/openai_functions.py @@ -1,12 +1,12 @@ from operator import itemgetter from typing import Any, Callable, List, Mapping, Optional, Union +from langchain_core.runnables import RouterRunnable, Runnable +from langchain_core.runnables.base import RunnableBindingBase +from langchain_core.schema.messages import BaseMessage from typing_extensions import TypedDict from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser -from langchain.schema.messages import BaseMessage -from langchain.schema.runnable import RouterRunnable, Runnable -from langchain.schema.runnable.base import RunnableBindingBase class OpenAIFunction(TypedDict): diff --git a/libs/langchain/langchain/schema/__init__.py b/libs/langchain/langchain/schema/__init__.py index be830b10aa9..7e1742cc327 100644 --- a/libs/langchain/langchain/schema/__init__.py +++ b/libs/langchain/langchain/schema/__init__.py @@ -1,11 +1,11 @@ """**Schemas** are the LangChain Base Classes and Interfaces.""" -from langchain.schema.agent import AgentAction, AgentFinish -from langchain.schema.cache import BaseCache -from langchain.schema.chat_history import BaseChatMessageHistory -from langchain.schema.document import BaseDocumentTransformer, Document -from langchain.schema.exceptions import LangChainException -from langchain.schema.memory import BaseMemory -from langchain.schema.messages import ( +from langchain_core.schema.agent import AgentAction, AgentFinish +from langchain_core.schema.cache import BaseCache +from langchain_core.schema.chat_history import BaseChatMessageHistory +from langchain_core.schema.document import 
BaseDocumentTransformer, Document +from langchain_core.schema.exceptions import LangChainException +from langchain_core.schema.memory import BaseMemory +from langchain_core.schema.messages import ( AIMessage, BaseMessage, ChatMessage, @@ -18,23 +18,23 @@ from langchain.schema.messages import ( messages_from_dict, messages_to_dict, ) -from langchain.schema.output import ( +from langchain_core.schema.output import ( ChatGeneration, ChatResult, Generation, LLMResult, RunInfo, ) -from langchain.schema.output_parser import ( +from langchain_core.schema.output_parser import ( BaseLLMOutputParser, BaseOutputParser, OutputParserException, StrOutputParser, ) -from langchain.schema.prompt import PromptValue -from langchain.schema.prompt_template import BasePromptTemplate, format_document -from langchain.schema.retriever import BaseRetriever -from langchain.schema.storage import BaseStore +from langchain_core.schema.prompt import PromptValue +from langchain_core.schema.prompt_template import BasePromptTemplate, format_document +from langchain_core.schema.retriever import BaseRetriever +from langchain_core.schema.storage import BaseStore RUN_KEY = "__run" Memory = BaseMemory diff --git a/libs/langchain/langchain/schema/agent.py b/libs/langchain/langchain/schema/agent.py index 447dcbd4415..f5064f91811 100644 --- a/libs/langchain/langchain/schema/agent.py +++ b/libs/langchain/langchain/schema/agent.py @@ -1,74 +1,3 @@ -from __future__ import annotations +from langchain_core.schema.agent import AgentAction, AgentActionMessageLog, AgentFinish -from typing import Any, Literal, Sequence, Union - -from langchain.load.serializable import Serializable -from langchain.schema.messages import BaseMessage - - -class AgentAction(Serializable): - """A full description of an action for an ActionAgent to execute.""" - - tool: str - """The name of the Tool to execute.""" - tool_input: Union[str, dict] - """The input to pass in to the Tool.""" - log: str - """Additional information to log about the action. - This log can be used in a few ways. First, it can be used to audit - what exactly the LLM predicted to lead to this (tool, tool_input). - Second, it can be used in future iterations to show the LLMs prior - thoughts. This is useful when (tool, tool_input) does not contain - full information about the LLM prediction (for example, any `thought` - before the tool/tool_input).""" - type: Literal["AgentAction"] = "AgentAction" - - def __init__( - self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any - ): - """Override init to support instantiation by position for backward compat.""" - super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs) - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether or not the class is serializable.""" - return True - - -class AgentActionMessageLog(AgentAction): - message_log: Sequence[BaseMessage] - """Similar to log, this can be used to pass along extra - information about what exact messages were predicted by the LLM - before parsing out the (tool, tool_input). This is again useful - if (tool, tool_input) cannot be used to fully recreate the LLM - prediction, and you need that LLM prediction (for future agent iteration). - Compared to `log`, this is useful when the underlying LLM is a - ChatModel (and therefore returns messages rather than a string).""" - # Ignoring type because we're overriding the type from AgentAction. - # And this is the correct thing to do in this case. - # The type literal is used for serialization purposes. 
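The hunks above reduce langchain.schema.agent to a thin re-export of the langchain_core implementations, so the legacy import path keeps resolving to the same classes. A minimal compatibility sketch, assuming the langchain and langchain_core versions introduced by this patch are both installed; the tool name and log strings are made up for illustration:

from langchain.schema.agent import AgentAction as LegacyAgentAction
from langchain_core.schema.agent import AgentAction, AgentFinish

# The legacy module now only forwards names, so both paths yield one class object.
assert LegacyAgentAction is AgentAction

# Positional construction is preserved by the backward-compat __init__ overrides.
action = AgentAction("search", "weather in SF", "LLM picked the search tool")
finish = AgentFinish({"output": "72F and sunny"}, "Final Answer: 72F and sunny")
print(action.tool, finish.return_values["output"])

Since the shim only forwards names, isinstance checks written against the old path keep passing after the split.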
- type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore - - -class AgentFinish(Serializable): - """The final return value of an ActionAgent.""" - - return_values: dict - """Dictionary of return values.""" - log: str - """Additional information to log about the return value. - This is used to pass along the full LLM prediction, not just the parsed out - return value. For example, if the full LLM prediction was - `Final Answer: 2` you may want to just return `2` as a return value, but pass - along the full string as a `log` (for debugging or observability purposes). - """ - type: Literal["AgentFinish"] = "AgentFinish" - - def __init__(self, return_values: dict, log: str, **kwargs: Any): - """Override init to support instantiation by position for backward compat.""" - super().__init__(return_values=return_values, log=log, **kwargs) - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether or not the class is serializable.""" - return True +__all__ = ["AgentAction", "AgentActionMessageLog", "AgentFinish"] diff --git a/libs/langchain/langchain/schema/cache.py b/libs/langchain/langchain/schema/cache.py index 7adb07fd1db..145b1674354 100644 --- a/libs/langchain/langchain/schema/cache.py +++ b/libs/langchain/langchain/schema/cache.py @@ -1,24 +1,3 @@ -from __future__ import annotations +from langchain_core.schema.cache import BaseCache -from abc import ABC, abstractmethod -from typing import Any, Optional, Sequence - -from langchain.schema.output import Generation - -RETURN_VAL_TYPE = Sequence[Generation] - - -class BaseCache(ABC): - """Base interface for cache.""" - - @abstractmethod - def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: - """Look up based on prompt and llm_string.""" - - @abstractmethod - def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: - """Update cache based on prompt and llm_string.""" - - @abstractmethod - def clear(self, **kwargs: Any) -> None: - """Clear cache that can take additional keyword arguments.""" +__all__ = ["BaseCache"] diff --git a/libs/langchain/langchain/schema/callbacks/base.py b/libs/langchain/langchain/schema/callbacks/base.py index 359496aa32a..ab2053a075c 100644 --- a/libs/langchain/langchain/schema/callbacks/base.py +++ b/libs/langchain/langchain/schema/callbacks/base.py @@ -1,598 +1,23 @@ -"""Base callback handler that can be used to handle callbacks in langchain.""" -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union -from uuid import UUID - -from tenacity import RetryCallState - -from langchain.schema.agent import AgentAction, AgentFinish -from langchain.schema.document import Document -from langchain.schema.messages import BaseMessage -from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult - - -class RetrieverManagerMixin: - """Mixin for Retriever callbacks.""" - - def on_retriever_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when Retriever errors.""" - - def on_retriever_end( - self, - documents: Sequence[Document], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when Retriever ends running.""" - - -class LLMManagerMixin: - """Mixin for LLM callbacks.""" - - def on_llm_new_token( - self, - token: str, - *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - run_id: UUID, - parent_run_id: 
Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on new LLM token. Only available when streaming is enabled. - - Args: - token (str): The new token. - chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk, - containing content and other information. - """ - - def on_llm_end( - self, - response: LLMResult, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when LLM ends running.""" - - def on_llm_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when LLM errors.""" - - -class ChainManagerMixin: - """Mixin for chain callbacks.""" - - def on_chain_end( - self, - outputs: Dict[str, Any], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when chain ends running.""" - - def on_chain_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when chain errors.""" - - def on_agent_action( - self, - action: AgentAction, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on agent action.""" - - def on_agent_finish( - self, - finish: AgentFinish, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on agent end.""" - - -class ToolManagerMixin: - """Mixin for tool callbacks.""" - - def on_tool_end( - self, - output: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when tool ends running.""" - - def on_tool_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run when tool errors.""" - - -class CallbackManagerMixin: - """Mixin for callback manager.""" - - def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when LLM starts running.""" - - def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when a chat model starts running.""" - raise NotImplementedError( - f"{self.__class__.__name__} does not implement `on_chat_model_start`" - ) - - def on_retriever_start( - self, - serialized: Dict[str, Any], - query: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when Retriever starts running.""" - - def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when chain starts running.""" - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when tool starts running.""" - - -class RunManagerMixin: - """Mixin for run manager.""" - - def 
on_text( - self, - text: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on arbitrary text.""" - - def on_retry( - self, - retry_state: RetryCallState, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on a retry event.""" - - -class BaseCallbackHandler( - LLMManagerMixin, - ChainManagerMixin, - ToolManagerMixin, - RetrieverManagerMixin, +from langchain_core.callbacks.base import ( + AsyncCallbackHandler, + BaseCallbackHandler, + BaseCallbackManager, CallbackManagerMixin, + ChainManagerMixin, + LLMManagerMixin, + RetrieverManagerMixin, RunManagerMixin, -): - """Base callback handler that handles callbacks from LangChain.""" + ToolManagerMixin, +) - raise_error: bool = False - - run_inline: bool = False - - @property - def ignore_llm(self) -> bool: - """Whether to ignore LLM callbacks.""" - return False - - @property - def ignore_retry(self) -> bool: - """Whether to ignore retry callbacks.""" - return False - - @property - def ignore_chain(self) -> bool: - """Whether to ignore chain callbacks.""" - return False - - @property - def ignore_agent(self) -> bool: - """Whether to ignore agent callbacks.""" - return False - - @property - def ignore_retriever(self) -> bool: - """Whether to ignore retriever callbacks.""" - return False - - @property - def ignore_chat_model(self) -> bool: - """Whether to ignore chat model callbacks.""" - return False - - -class AsyncCallbackHandler(BaseCallbackHandler): - """Async callback handler that handles callbacks from LangChain.""" - - async def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - """Run when LLM starts running.""" - - async def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - """Run when a chat model starts running.""" - raise NotImplementedError( - f"{self.__class__.__name__} does not implement `on_chat_model_start`" - ) - - async def on_llm_new_token( - self, - token: str, - *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on new LLM token. 
Only available when streaming is enabled.""" - - async def on_llm_end( - self, - response: LLMResult, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when LLM ends running.""" - - async def on_llm_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when LLM errors.""" - - async def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - """Run when chain starts running.""" - - async def on_chain_end( - self, - outputs: Dict[str, Any], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when chain ends running.""" - - async def on_chain_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when chain errors.""" - - async def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - """Run when tool starts running.""" - - async def on_tool_end( - self, - output: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when tool ends running.""" - - async def on_tool_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run when tool errors.""" - - async def on_text( - self, - text: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on arbitrary text.""" - - async def on_retry( - self, - retry_state: RetryCallState, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - """Run on a retry event.""" - - async def on_agent_action( - self, - action: AgentAction, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on agent action.""" - - async def on_agent_finish( - self, - finish: AgentFinish, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on agent end.""" - - async def on_retriever_start( - self, - serialized: Dict[str, Any], - query: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - """Run on retriever start.""" - - async def on_retriever_end( - self, - documents: Sequence[Document], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on retriever end.""" - - async def on_retriever_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - """Run on retriever error.""" - - -T = TypeVar("T", 
bound="BaseCallbackManager") - - -class BaseCallbackManager(CallbackManagerMixin): - """Base callback manager that handles callbacks from LangChain.""" - - def __init__( - self, - handlers: List[BaseCallbackHandler], - inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, - parent_run_id: Optional[UUID] = None, - *, - tags: Optional[List[str]] = None, - inheritable_tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - ) -> None: - """Initialize callback manager.""" - self.handlers: List[BaseCallbackHandler] = handlers - self.inheritable_handlers: List[BaseCallbackHandler] = ( - inheritable_handlers or [] - ) - self.parent_run_id: Optional[UUID] = parent_run_id - self.tags = tags or [] - self.inheritable_tags = inheritable_tags or [] - self.metadata = metadata or {} - self.inheritable_metadata = inheritable_metadata or {} - - def copy(self: T) -> T: - """Copy the callback manager.""" - return self.__class__( - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - @property - def is_async(self) -> bool: - """Whether the callback manager is async.""" - return False - - def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: - """Add a handler to the callback manager.""" - if handler not in self.handlers: - self.handlers.append(handler) - if inherit and handler not in self.inheritable_handlers: - self.inheritable_handlers.append(handler) - - def remove_handler(self, handler: BaseCallbackHandler) -> None: - """Remove a handler from the callback manager.""" - self.handlers.remove(handler) - self.inheritable_handlers.remove(handler) - - def set_handlers( - self, handlers: List[BaseCallbackHandler], inherit: bool = True - ) -> None: - """Set handlers as the only handlers on the callback manager.""" - self.handlers = [] - self.inheritable_handlers = [] - for handler in handlers: - self.add_handler(handler, inherit=inherit) - - def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: - """Set handler as the only handler on the callback manager.""" - self.set_handlers([handler], inherit=inherit) - - def add_tags(self, tags: List[str], inherit: bool = True) -> None: - for tag in tags: - if tag in self.tags: - self.remove_tags([tag]) - self.tags.extend(tags) - if inherit: - self.inheritable_tags.extend(tags) - - def remove_tags(self, tags: List[str]) -> None: - for tag in tags: - self.tags.remove(tag) - self.inheritable_tags.remove(tag) - - def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None: - self.metadata.update(metadata) - if inherit: - self.inheritable_metadata.update(metadata) - - def remove_metadata(self, keys: List[str]) -> None: - for key in keys: - self.metadata.pop(key) - self.inheritable_metadata.pop(key) - - -Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] +__all__ = [ + "RetrieverManagerMixin", + "LLMManagerMixin", + "ChainManagerMixin", + "ToolManagerMixin", + "CallbackManagerMixin", + "RunManagerMixin", + "BaseCallbackHandler", + "AsyncCallbackHandler", + "BaseCallbackManager", +] diff --git a/libs/langchain/langchain/schema/callbacks/manager.py b/libs/langchain/langchain/schema/callbacks/manager.py index 0491f1fed74..71e967faf2d 100644 --- 
a/libs/langchain/langchain/schema/callbacks/manager.py +++ b/libs/langchain/langchain/schema/callbacks/manager.py @@ -1,2075 +1,53 @@ -from __future__ import annotations - -import asyncio -import functools -import logging -import os -import uuid -from concurrent.futures import ThreadPoolExecutor -from contextlib import asynccontextmanager, contextmanager -from contextvars import ContextVar -from typing import ( - TYPE_CHECKING, - Any, - AsyncGenerator, - Coroutine, - Dict, - Generator, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, -) -from uuid import UUID - -from langsmith import utils as ls_utils -from langsmith.run_helpers import get_run_tree_context -from tenacity import RetryCallState - -from langchain.schema import ( - AgentAction, - AgentFinish, - Document, - LLMResult, -) -from langchain.schema.callbacks.base import ( - BaseCallbackHandler, - BaseCallbackManager, - Callbacks, - ChainManagerMixin, - LLMManagerMixin, - RetrieverManagerMixin, - RunManagerMixin, - ToolManagerMixin, -) -from langchain.schema.callbacks.stdout import StdOutCallbackHandler -from langchain.schema.callbacks.tracers import run_collector -from langchain.schema.callbacks.tracers.langchain import ( - LangChainTracer, -) -from langchain.schema.callbacks.tracers.langchain_v1 import ( - LangChainTracerV1, - TracerSessionV1, -) -from langchain.schema.callbacks.tracers.stdout import ConsoleCallbackHandler -from langchain.schema.messages import BaseMessage, get_buffer_string -from langchain.schema.output import ChatGenerationChunk, GenerationChunk - -if TYPE_CHECKING: - from langsmith import Client as LangSmithClient - -logger = logging.getLogger(__name__) - -tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501 - "tracing_callback", default=None -) - -tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501 - "tracing_callback_v2", default=None -) -run_collector_var: ContextVar[ - Optional[run_collector.RunCollectorCallbackHandler] -] = ContextVar( # noqa: E501 - "run_collector", default=None -) - - -def _get_debug() -> bool: - from langchain.globals import get_debug - - return get_debug() - - -@contextmanager -def tracing_enabled( - session_name: str = "default", -) -> Generator[TracerSessionV1, None, None]: - """Get the Deprecated LangChainTracer in a context manager. - - Args: - session_name (str, optional): The name of the session. - Defaults to "default". - - Returns: - TracerSessionV1: The LangChainTracer session. - - Example: - >>> with tracing_enabled() as session: - ... # Use the LangChainTracer session - """ - cb = LangChainTracerV1() - session = cast(TracerSessionV1, cb.load_session(session_name)) - try: - tracing_callback_var.set(cb) - yield session - finally: - tracing_callback_var.set(None) - - -@contextmanager -def tracing_v2_enabled( - project_name: Optional[str] = None, - *, - example_id: Optional[Union[str, UUID]] = None, - tags: Optional[List[str]] = None, - client: Optional[LangSmithClient] = None, -) -> Generator[LangChainTracer, None, None]: - """Instruct LangChain to log all runs in context to LangSmith. - - Args: - project_name (str, optional): The name of the project. - Defaults to "default". - example_id (str or UUID, optional): The ID of the example. - Defaults to None. - tags (List[str], optional): The tags to add to the run. - Defaults to None. - - Returns: - None - - Example: - >>> with tracing_v2_enabled(): - ... 
# LangChain code will automatically be traced - - You can use this to fetch the LangSmith run URL: - - >>> with tracing_v2_enabled() as cb: - ... chain.invoke("foo") - ... run_url = cb.get_run_url() - """ - if isinstance(example_id, str): - example_id = UUID(example_id) - cb = LangChainTracer( - example_id=example_id, - project_name=project_name, - tags=tags, - client=client, - ) - try: - tracing_v2_callback_var.set(cb) - yield cb - finally: - tracing_v2_callback_var.set(None) - - -@contextmanager -def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]: - """Collect all run traces in context. - - Returns: - run_collector.RunCollectorCallbackHandler: The run collector callback handler. - - Example: - >>> with collect_runs() as runs_cb: - chain.invoke("foo") - run_id = runs_cb.traced_runs[0].id - """ - cb = run_collector.RunCollectorCallbackHandler() - run_collector_var.set(cb) - yield cb - run_collector_var.set(None) - - -def _get_trace_callbacks( - project_name: Optional[str] = None, - example_id: Optional[Union[str, UUID]] = None, - callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None, -) -> Callbacks: - if _tracing_v2_is_enabled(): - project_name_ = project_name or _get_tracer_project() - tracer = tracing_v2_callback_var.get() or LangChainTracer( - project_name=project_name_, - example_id=example_id, - ) - if callback_manager is None: - cb = cast(Callbacks, [tracer]) - else: - if not any( - isinstance(handler, LangChainTracer) - for handler in callback_manager.handlers - ): - callback_manager.add_handler(tracer, True) - # If it already has a LangChainTracer, we don't need to add another one. - # this would likely mess up the trace hierarchy. - cb = callback_manager - else: - cb = None - return cb - - -@contextmanager -def trace_as_chain_group( - group_name: str, - callback_manager: Optional[CallbackManager] = None, - *, - inputs: Optional[Dict[str, Any]] = None, - project_name: Optional[str] = None, - example_id: Optional[Union[str, UUID]] = None, - run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, -) -> Generator[CallbackManagerForChainGroup, None, None]: - """Get a callback manager for a chain group in a context manager. - Useful for grouping different calls together as a single run even if - they aren't composed in a single chain. - - Args: - group_name (str): The name of the chain group. - callback_manager (CallbackManager, optional): The callback manager to use. - inputs (Dict[str, Any], optional): The inputs to the chain group. - project_name (str, optional): The name of the project. - Defaults to None. - example_id (str or UUID, optional): The ID of the example. - Defaults to None. - run_id (UUID, optional): The ID of the run. - tags (List[str], optional): The inheritable tags to apply to all runs. - Defaults to None. - - Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. - - Returns: - CallbackManagerForChainGroup: The callback manager for the chain group. - - Example: - .. 
code-block:: python - - llm_input = "Foo" - with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: - # Use the callback manager for the chain group - res = llm.predict(llm_input, callbacks=manager) - manager.on_chain_end({"output": res}) - """ # noqa: E501 - cb = _get_trace_callbacks( - project_name, example_id, callback_manager=callback_manager - ) - cm = CallbackManager.configure( - inheritable_callbacks=cb, - inheritable_tags=tags, - ) - - run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id) - child_cm = run_manager.get_child() - group_cm = CallbackManagerForChainGroup( - child_cm.handlers, - child_cm.inheritable_handlers, - child_cm.parent_run_id, - parent_run_manager=run_manager, - tags=child_cm.tags, - inheritable_tags=child_cm.inheritable_tags, - metadata=child_cm.metadata, - inheritable_metadata=child_cm.inheritable_metadata, - ) - try: - yield group_cm - except Exception as e: - if not group_cm.ended: - run_manager.on_chain_error(e) - raise e - else: - if not group_cm.ended: - run_manager.on_chain_end({}) - - -@asynccontextmanager -async def atrace_as_chain_group( - group_name: str, - callback_manager: Optional[AsyncCallbackManager] = None, - *, - inputs: Optional[Dict[str, Any]] = None, - project_name: Optional[str] = None, - example_id: Optional[Union[str, UUID]] = None, - run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, -) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]: - """Get an async callback manager for a chain group in a context manager. - Useful for grouping different async calls together as a single run even if - they aren't composed in a single chain. - - Args: - group_name (str): The name of the chain group. - callback_manager (AsyncCallbackManager, optional): The async callback manager to use, - which manages tracing and other callback behavior. - project_name (str, optional): The name of the project. - Defaults to None. - example_id (str or UUID, optional): The ID of the example. - Defaults to None. - run_id (UUID, optional): The ID of the run. - tags (List[str], optional): The inheritable tags to apply to all runs. - Defaults to None. - Returns: - AsyncCallbackManager: The async callback manager for the chain group. - - Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. - - Example: - .. 
code-block:: python - - llm_input = "Foo" - async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: - # Use the async callback manager for the chain group - res = await llm.apredict(llm_input, callbacks=manager) - await manager.on_chain_end({"output": res}) - """ # noqa: E501 - cb = _get_trace_callbacks( - project_name, example_id, callback_manager=callback_manager - ) - cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags) - - run_manager = await cm.on_chain_start( - {"name": group_name}, inputs or {}, run_id=run_id - ) - child_cm = run_manager.get_child() - group_cm = AsyncCallbackManagerForChainGroup( - child_cm.handlers, - child_cm.inheritable_handlers, - child_cm.parent_run_id, - parent_run_manager=run_manager, - tags=child_cm.tags, - inheritable_tags=child_cm.inheritable_tags, - metadata=child_cm.metadata, - inheritable_metadata=child_cm.inheritable_metadata, - ) - try: - yield group_cm - except Exception as e: - if not group_cm.ended: - await run_manager.on_chain_error(e) - raise e - else: - if not group_cm.ended: - await run_manager.on_chain_end({}) - - -def handle_event( - handlers: List[BaseCallbackHandler], - event_name: str, - ignore_condition_name: Optional[str], - *args: Any, - **kwargs: Any, -) -> None: - """Generic event handler for CallbackManager. - - Note: This function is used by langserve to handle events. - - Args: - handlers: The list of handlers that will handle the event - event_name: The name of the event (e.g., "on_llm_start") - ignore_condition_name: Name of the attribute defined on handler - that if True will cause the handler to be skipped for the given event - *args: The arguments to pass to the event handler - **kwargs: The keyword arguments to pass to the event handler - """ - coros: List[Coroutine[Any, Any, Any]] = [] - - try: - message_strings: Optional[List[str]] = None - for handler in handlers: - try: - if ignore_condition_name is None or not getattr( - handler, ignore_condition_name - ): - event = getattr(handler, event_name)(*args, **kwargs) - if asyncio.iscoroutine(event): - coros.append(event) - except NotImplementedError as e: - if event_name == "on_chat_model_start": - if message_strings is None: - message_strings = [get_buffer_string(m) for m in args[1]] - handle_event( - [handler], - "on_llm_start", - "ignore_llm", - args[0], - message_strings, - *args[2:], - **kwargs, - ) - else: - handler_name = handler.__class__.__name__ - logger.warning( - f"NotImplementedError in {handler_name}.{event_name}" - f" callback: {repr(e)}" - ) - except Exception as e: - logger.warning( - f"Error in {handler.__class__.__name__}.{event_name} callback:" - f" {repr(e)}" - ) - if handler.raise_error: - raise e - finally: - if coros: - try: - # Raises RuntimeError if there is no current event loop. - asyncio.get_running_loop() - loop_running = True - except RuntimeError: - loop_running = False - - if loop_running: - # If we try to submit this coroutine to the running loop - # we end up in a deadlock, as we'd have gotten here from a - # running coroutine, which we cannot interrupt to run this one. - # The solution is to create a new loop in a new thread. 
- with ThreadPoolExecutor(1) as executor: - executor.submit(_run_coros, coros).result() - else: - _run_coros(coros) - - -def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None: - if hasattr(asyncio, "Runner"): - # Python 3.11+ - # Run the coroutines in a new event loop, taking care to - # - install signal handlers - # - run pending tasks scheduled by `coros` - # - close asyncgens and executors - # - close the loop - with asyncio.Runner() as runner: - # Run the coroutine, get the result - for coro in coros: - runner.run(coro) - - # Run pending tasks scheduled by coros until they are all done - while pending := asyncio.all_tasks(runner.get_loop()): - runner.run(asyncio.wait(pending)) - else: - # Before Python 3.11 we need to run each coroutine in a new event loop - # as the Runner api is not available. - for coro in coros: - asyncio.run(coro) - - -async def _ahandle_event_for_handler( - handler: BaseCallbackHandler, - event_name: str, - ignore_condition_name: Optional[str], - *args: Any, - **kwargs: Any, -) -> None: - try: - if ignore_condition_name is None or not getattr(handler, ignore_condition_name): - event = getattr(handler, event_name) - if asyncio.iscoroutinefunction(event): - await event(*args, **kwargs) - else: - if handler.run_inline: - event(*args, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, functools.partial(event, *args, **kwargs) - ) - except NotImplementedError as e: - if event_name == "on_chat_model_start": - message_strings = [get_buffer_string(m) for m in args[1]] - await _ahandle_event_for_handler( - handler, - "on_llm_start", - "ignore_llm", - args[0], - message_strings, - *args[2:], - **kwargs, - ) - else: - logger.warning( - f"NotImplementedError in {handler.__class__.__name__}.{event_name}" - f" callback: {repr(e)}" - ) - except Exception as e: - logger.warning( - f"Error in {handler.__class__.__name__}.{event_name} callback:" - f" {repr(e)}" - ) - if handler.raise_error: - raise e - - -async def ahandle_event( - handlers: List[BaseCallbackHandler], - event_name: str, - ignore_condition_name: Optional[str], - *args: Any, - **kwargs: Any, -) -> None: - """Generic event handler for AsyncCallbackManager. - - Note: This function is used by langserve to handle events. 
- - Args: - handlers: The list of handlers that will handle the event - event_name: The name of the event (e.g., "on_llm_start") - ignore_condition_name: Name of the attribute defined on handler - that if True will cause the handler to be skipped for the given event - *args: The arguments to pass to the event handler - **kwargs: The keyword arguments to pass to the event handler - """ - for handler in [h for h in handlers if h.run_inline]: - await _ahandle_event_for_handler( - handler, event_name, ignore_condition_name, *args, **kwargs - ) - await asyncio.gather( - *( - _ahandle_event_for_handler( - handler, event_name, ignore_condition_name, *args, **kwargs - ) - for handler in handlers - if not handler.run_inline - ) - ) - - -BRM = TypeVar("BRM", bound="BaseRunManager") - - -class BaseRunManager(RunManagerMixin): - """Base class for run manager (a bound callback manager).""" - - def __init__( - self, - *, - run_id: UUID, - handlers: List[BaseCallbackHandler], - inheritable_handlers: List[BaseCallbackHandler], - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - inheritable_tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - ) -> None: - """Initialize the run manager. - - Args: - run_id (UUID): The ID of the run. - handlers (List[BaseCallbackHandler]): The list of handlers. - inheritable_handlers (List[BaseCallbackHandler]): - The list of inheritable handlers. - parent_run_id (UUID, optional): The ID of the parent run. - Defaults to None. - tags (Optional[List[str]]): The list of tags. - inheritable_tags (Optional[List[str]]): The list of inheritable tags. - metadata (Optional[Dict[str, Any]]): The metadata. - inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata. - """ - self.run_id = run_id - self.handlers = handlers - self.inheritable_handlers = inheritable_handlers - self.parent_run_id = parent_run_id - self.tags = tags or [] - self.inheritable_tags = inheritable_tags or [] - self.metadata = metadata or {} - self.inheritable_metadata = inheritable_metadata or {} - - @classmethod - def get_noop_manager(cls: Type[BRM]) -> BRM: - """Return a manager that doesn't perform any operations. - - Returns: - BaseRunManager: The noop manager. - """ - return cls( - run_id=uuid.uuid4(), - handlers=[], - inheritable_handlers=[], - tags=[], - inheritable_tags=[], - metadata={}, - inheritable_metadata={}, - ) - - -class RunManager(BaseRunManager): - """Sync Run Manager.""" - - def on_text( - self, - text: str, - **kwargs: Any, - ) -> Any: - """Run when text is received. - - Args: - text (str): The received text. - - Returns: - Any: The result of the callback. - """ - handle_event( - self.handlers, - "on_text", - None, - text, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_retry( - self, - retry_state: RetryCallState, - **kwargs: Any, - ) -> None: - handle_event( - self.handlers, - "on_retry", - "ignore_retry", - retry_state, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class ParentRunManager(RunManager): - """Sync Parent Run Manager.""" - - def get_child(self, tag: Optional[str] = None) -> CallbackManager: - """Get a child callback manager. - - Args: - tag (str, optional): The tag for the child callback manager. - Defaults to None. - - Returns: - CallbackManager: The child callback manager. 
- """ - manager = CallbackManager(handlers=[], parent_run_id=self.run_id) - manager.set_handlers(self.inheritable_handlers) - manager.add_tags(self.inheritable_tags) - manager.add_metadata(self.inheritable_metadata) - if tag is not None: - manager.add_tags([tag], False) - return manager - - -class AsyncRunManager(BaseRunManager): - """Async Run Manager.""" - - async def on_text( - self, - text: str, - **kwargs: Any, - ) -> Any: - """Run when text is received. - - Args: - text (str): The received text. - - Returns: - Any: The result of the callback. - """ - await ahandle_event( - self.handlers, - "on_text", - None, - text, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_retry( - self, - retry_state: RetryCallState, - **kwargs: Any, - ) -> None: - await ahandle_event( - self.handlers, - "on_retry", - "ignore_retry", - retry_state, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class AsyncParentRunManager(AsyncRunManager): - """Async Parent Run Manager.""" - - def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: - """Get a child callback manager. - - Args: - tag (str, optional): The tag for the child callback manager. - Defaults to None. - - Returns: - AsyncCallbackManager: The child callback manager. - """ - manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id) - manager.set_handlers(self.inheritable_handlers) - manager.add_tags(self.inheritable_tags) - manager.add_metadata(self.inheritable_metadata) - if tag is not None: - manager.add_tags([tag], False) - return manager - - -class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): - """Callback manager for LLM run.""" - - def on_llm_new_token( - self, - token: str, - *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - **kwargs: Any, - ) -> None: - """Run when LLM generates a new token. - - Args: - token (str): The new token. - """ - handle_event( - self.handlers, - "on_llm_new_token", - "ignore_llm", - token=token, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - chunk=chunk, - **kwargs, - ) - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running. - - Args: - response (LLMResult): The LLM result. - """ - handle_event( - self.handlers, - "on_llm_end", - "ignore_llm", - response, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_llm_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when LLM errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - handle_event( - self.handlers, - "on_llm_error", - "ignore_llm", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): - """Async callback manager for LLM run.""" - - async def on_llm_new_token( - self, - token: str, - *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - **kwargs: Any, - ) -> None: - """Run when LLM generates a new token. - - Args: - token (str): The new token. - """ - await ahandle_event( - self.handlers, - "on_llm_new_token", - "ignore_llm", - token, - chunk=chunk, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running. 
- - Args: - response (LLMResult): The LLM result. - """ - await ahandle_event( - self.handlers, - "on_llm_end", - "ignore_llm", - response, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_llm_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when LLM errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - await ahandle_event( - self.handlers, - "on_llm_error", - "ignore_llm", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): - """Callback manager for chain run.""" - - def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: - """Run when chain ends running. - - Args: - outputs (Union[Dict[str, Any], Any]): The outputs of the chain. - """ - handle_event( - self.handlers, - "on_chain_end", - "ignore_chain", - outputs, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_chain_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when chain errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - handle_event( - self.handlers, - "on_chain_error", - "ignore_chain", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run when agent action is received. - - Args: - action (AgentAction): The agent action. - - Returns: - Any: The result of the callback. - """ - handle_event( - self.handlers, - "on_agent_action", - "ignore_agent", - action, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: - """Run when agent finish is received. - - Args: - finish (AgentFinish): The agent finish. - - Returns: - Any: The result of the callback. - """ - handle_event( - self.handlers, - "on_agent_finish", - "ignore_agent", - finish, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): - """Async callback manager for chain run.""" - - async def on_chain_end( - self, outputs: Union[Dict[str, Any], Any], **kwargs: Any - ) -> None: - """Run when chain ends running. - - Args: - outputs (Union[Dict[str, Any], Any]): The outputs of the chain. - """ - await ahandle_event( - self.handlers, - "on_chain_end", - "ignore_chain", - outputs, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_chain_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when chain errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - await ahandle_event( - self.handlers, - "on_chain_error", - "ignore_chain", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run when agent action is received. - - Args: - action (AgentAction): The agent action. - - Returns: - Any: The result of the callback. 
- """ - await ahandle_event( - self.handlers, - "on_agent_action", - "ignore_agent", - action, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: - """Run when agent finish is received. - - Args: - finish (AgentFinish): The agent finish. - - Returns: - Any: The result of the callback. - """ - await ahandle_event( - self.handlers, - "on_agent_finish", - "ignore_agent", - finish, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): - """Callback manager for tool run.""" - - def on_tool_end( - self, - output: str, - **kwargs: Any, - ) -> None: - """Run when tool ends running. - - Args: - output (str): The output of the tool. - """ - handle_event( - self.handlers, - "on_tool_end", - "ignore_agent", - output, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_tool_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when tool errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - handle_event( - self.handlers, - "on_tool_error", - "ignore_agent", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): - """Async callback manager for tool run.""" - - async def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running. - - Args: - output (str): The output of the tool. - """ - await ahandle_event( - self.handlers, - "on_tool_end", - "ignore_agent", - output, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_tool_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when tool errors. - - Args: - error (Exception or KeyboardInterrupt): The error. 
- """ - await ahandle_event( - self.handlers, - "on_tool_error", - "ignore_agent", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): - """Callback manager for retriever run.""" - - def on_retriever_end( - self, - documents: Sequence[Document], - **kwargs: Any, - ) -> None: - """Run when retriever ends running.""" - handle_event( - self.handlers, - "on_retriever_end", - "ignore_retriever", - documents, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - def on_retriever_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when retriever errors.""" - handle_event( - self.handlers, - "on_retriever_error", - "ignore_retriever", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class AsyncCallbackManagerForRetrieverRun( +from langchain_core.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForChainGroup, + AsyncCallbackManagerForChainRun, + AsyncCallbackManagerForLLMRun, + AsyncCallbackManagerForRetrieverRun, + AsyncCallbackManagerForToolRun, AsyncParentRunManager, - RetrieverManagerMixin, -): - """Async callback manager for retriever run.""" - - async def on_retriever_end( - self, documents: Sequence[Document], **kwargs: Any - ) -> None: - """Run when retriever ends running.""" - await ahandle_event( - self.handlers, - "on_retriever_end", - "ignore_retriever", - documents, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - async def on_retriever_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when retriever errors.""" - await ahandle_event( - self.handlers, - "on_retriever_error", - "ignore_retriever", - error, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - - -class CallbackManager(BaseCallbackManager): - """Callback manager that handles callbacks from LangChain.""" - - def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - **kwargs: Any, - ) -> List[CallbackManagerForLLMRun]: - """Run when LLM starts running. - - Args: - serialized (Dict[str, Any]): The serialized LLM. - prompts (List[str]): The list of prompts. - run_id (UUID, optional): The ID of the run. Defaults to None. - - Returns: - List[CallbackManagerForLLMRun]: A callback manager for each - prompt as an LLM run. - """ - managers = [] - for prompt in prompts: - run_id_ = uuid.uuid4() - handle_event( - self.handlers, - "on_llm_start", - "ignore_llm", - serialized, - [prompt], - run_id=run_id_, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - managers.append( - CallbackManagerForLLMRun( - run_id=run_id_, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - ) - - return managers - - def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - **kwargs: Any, - ) -> List[CallbackManagerForLLMRun]: - """Run when LLM starts running. - - Args: - serialized (Dict[str, Any]): The serialized LLM. - messages (List[List[BaseMessage]]): The list of messages. - run_id (UUID, optional): The ID of the run. Defaults to None. 
- - Returns: - List[CallbackManagerForLLMRun]: A callback manager for each - list of messages as an LLM run. - """ - - managers = [] - for message_list in messages: - run_id_ = uuid.uuid4() - handle_event( - self.handlers, - "on_chat_model_start", - "ignore_chat_model", - serialized, - [message_list], - run_id=run_id_, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - managers.append( - CallbackManagerForLLMRun( - run_id=run_id_, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - ) - - return managers - - def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Union[Dict[str, Any], Any], - run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> CallbackManagerForChainRun: - """Run when chain starts running. - - Args: - serialized (Dict[str, Any]): The serialized chain. - inputs (Union[Dict[str, Any], Any]): The inputs to the chain. - run_id (UUID, optional): The ID of the run. Defaults to None. - - Returns: - CallbackManagerForChainRun: The callback manager for the chain run. - """ - if run_id is None: - run_id = uuid.uuid4() - handle_event( - self.handlers, - "on_chain_start", - "ignore_chain", - serialized, - inputs, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return CallbackManagerForChainRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> CallbackManagerForToolRun: - """Run when tool starts running. - - Args: - serialized (Dict[str, Any]): The serialized tool. - input_str (str): The input to the tool. - run_id (UUID, optional): The ID of the run. Defaults to None. - parent_run_id (UUID, optional): The ID of the parent run. Defaults to None. - - Returns: - CallbackManagerForToolRun: The callback manager for the tool run. 
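As context for the dispatch code being relocated in this hunk, a hedged sketch of how it is typically exercised; the TokenCounter handler and the "fake-llm" payload are illustrative stand-ins, not part of the patch. CallbackManager.configure assembles a manager from plain handlers, on_llm_start returns one CallbackManagerForLLMRun per prompt, and per-run events fan out to every registered handler:

from typing import Any, Optional
from uuid import UUID

from langchain.schema import Generation, LLMResult
from langchain.schema.callbacks.base import BaseCallbackHandler
from langchain.schema.callbacks.manager import CallbackManager


class TokenCounter(BaseCallbackHandler):
    """Illustrative handler: counts streamed tokens via the keyword-only signature."""

    def __init__(self) -> None:
        self.tokens = 0

    def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        self.tokens += 1


handler = TokenCounter()
manager = CallbackManager.configure(local_callbacks=[handler])
(run,) = manager.on_llm_start({"name": "fake-llm"}, ["Hello"])  # one run manager per prompt
run.on_llm_new_token("Hi")
run.on_llm_end(LLMResult(generations=[[Generation(text="Hi")]]))
print(handler.tokens)  # -> 1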
- """ - if run_id is None: - run_id = uuid.uuid4() - - handle_event( - self.handlers, - "on_tool_start", - "ignore_agent", - serialized, - input_str, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return CallbackManagerForToolRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - def on_retriever_start( - self, - serialized: Dict[str, Any], - query: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> CallbackManagerForRetrieverRun: - """Run when retriever starts running.""" - if run_id is None: - run_id = uuid.uuid4() - - handle_event( - self.handlers, - "on_retriever_start", - "ignore_retriever", - serialized, - query, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return CallbackManagerForRetrieverRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - @classmethod - def configure( - cls, - inheritable_callbacks: Callbacks = None, - local_callbacks: Callbacks = None, - verbose: bool = False, - inheritable_tags: Optional[List[str]] = None, - local_tags: Optional[List[str]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - local_metadata: Optional[Dict[str, Any]] = None, - ) -> CallbackManager: - """Configure the callback manager. - - Args: - inheritable_callbacks (Optional[Callbacks], optional): The inheritable - callbacks. Defaults to None. - local_callbacks (Optional[Callbacks], optional): The local callbacks. - Defaults to None. - verbose (bool, optional): Whether to enable verbose mode. Defaults to False. - inheritable_tags (Optional[List[str]], optional): The inheritable tags. - Defaults to None. - local_tags (Optional[List[str]], optional): The local tags. - Defaults to None. - inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable - metadata. Defaults to None. - local_metadata (Optional[Dict[str, Any]], optional): The local metadata. - Defaults to None. - - Returns: - CallbackManager: The configured callback manager. 
- """ - return _configure( - cls, - inheritable_callbacks, - local_callbacks, - verbose, - inheritable_tags, - local_tags, - inheritable_metadata, - local_metadata, - ) - - -class CallbackManagerForChainGroup(CallbackManager): - """Callback manager for the chain group.""" - - def __init__( - self, - handlers: List[BaseCallbackHandler], - inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, - parent_run_id: Optional[UUID] = None, - *, - parent_run_manager: CallbackManagerForChainRun, - **kwargs: Any, - ) -> None: - super().__init__( - handlers, - inheritable_handlers, - parent_run_id, - **kwargs, - ) - self.parent_run_manager = parent_run_manager - self.ended = False - - def copy(self) -> CallbackManagerForChainGroup: - return self.__class__( - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - parent_run_manager=self.parent_run_manager, - ) - - def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: - """Run when traced chain group ends. - - Args: - outputs (Union[Dict[str, Any], Any]): The outputs of the chain. - """ - self.ended = True - return self.parent_run_manager.on_chain_end(outputs, **kwargs) - - def on_chain_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when chain errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - self.ended = True - return self.parent_run_manager.on_chain_error(error, **kwargs) - - -class AsyncCallbackManager(BaseCallbackManager): - """Async callback manager that handles callbacks from LangChain.""" - - @property - def is_async(self) -> bool: - """Return whether the handler is async.""" - return True - - async def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - **kwargs: Any, - ) -> List[AsyncCallbackManagerForLLMRun]: - """Run when LLM starts running. - - Args: - serialized (Dict[str, Any]): The serialized LLM. - prompts (List[str]): The list of prompts. - run_id (UUID, optional): The ID of the run. Defaults to None. - - Returns: - List[AsyncCallbackManagerForLLMRun]: The list of async - callback managers, one for each LLM Run corresponding - to each prompt. - """ - - tasks = [] - managers = [] - - for prompt in prompts: - run_id_ = uuid.uuid4() - - tasks.append( - ahandle_event( - self.handlers, - "on_llm_start", - "ignore_llm", - serialized, - [prompt], - run_id=run_id_, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - ) - - managers.append( - AsyncCallbackManagerForLLMRun( - run_id=run_id_, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - ) - - await asyncio.gather(*tasks) - - return managers - - async def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - **kwargs: Any, - ) -> List[AsyncCallbackManagerForLLMRun]: - """Run when LLM starts running. - - Args: - serialized (Dict[str, Any]): The serialized LLM. - messages (List[List[BaseMessage]]): The list of messages. - run_id (UUID, optional): The ID of the run. Defaults to None. 
- - Returns: - List[AsyncCallbackManagerForLLMRun]: The list of - async callback managers, one for each LLM Run - corresponding to each inner message list. - """ - tasks = [] - managers = [] - - for message_list in messages: - run_id_ = uuid.uuid4() - - tasks.append( - ahandle_event( - self.handlers, - "on_chat_model_start", - "ignore_chat_model", - serialized, - [message_list], - run_id=run_id_, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - ) - - managers.append( - AsyncCallbackManagerForLLMRun( - run_id=run_id_, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - ) - - await asyncio.gather(*tasks) - return managers - - async def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Union[Dict[str, Any], Any], - run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> AsyncCallbackManagerForChainRun: - """Run when chain starts running. - - Args: - serialized (Dict[str, Any]): The serialized chain. - inputs (Union[Dict[str, Any], Any]): The inputs to the chain. - run_id (UUID, optional): The ID of the run. Defaults to None. - - Returns: - AsyncCallbackManagerForChainRun: The async callback manager - for the chain run. - """ - if run_id is None: - run_id = uuid.uuid4() - - await ahandle_event( - self.handlers, - "on_chain_start", - "ignore_chain", - serialized, - inputs, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return AsyncCallbackManagerForChainRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - async def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> AsyncCallbackManagerForToolRun: - """Run when tool starts running. - - Args: - serialized (Dict[str, Any]): The serialized tool. - input_str (str): The input to the tool. - run_id (UUID, optional): The ID of the run. Defaults to None. - parent_run_id (UUID, optional): The ID of the parent run. - Defaults to None. - - Returns: - AsyncCallbackManagerForToolRun: The async callback manager - for the tool run. 
- """ - if run_id is None: - run_id = uuid.uuid4() - - await ahandle_event( - self.handlers, - "on_tool_start", - "ignore_agent", - serialized, - input_str, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return AsyncCallbackManagerForToolRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - async def on_retriever_start( - self, - serialized: Dict[str, Any], - query: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> AsyncCallbackManagerForRetrieverRun: - """Run when retriever starts running.""" - if run_id is None: - run_id = uuid.uuid4() - - await ahandle_event( - self.handlers, - "on_retriever_start", - "ignore_retriever", - serialized, - query, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - metadata=self.metadata, - **kwargs, - ) - - return AsyncCallbackManagerForRetrieverRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - ) - - @classmethod - def configure( - cls, - inheritable_callbacks: Callbacks = None, - local_callbacks: Callbacks = None, - verbose: bool = False, - inheritable_tags: Optional[List[str]] = None, - local_tags: Optional[List[str]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - local_metadata: Optional[Dict[str, Any]] = None, - ) -> AsyncCallbackManager: - """Configure the async callback manager. - - Args: - inheritable_callbacks (Optional[Callbacks], optional): The inheritable - callbacks. Defaults to None. - local_callbacks (Optional[Callbacks], optional): The local callbacks. - Defaults to None. - verbose (bool, optional): Whether to enable verbose mode. Defaults to False. - inheritable_tags (Optional[List[str]], optional): The inheritable tags. - Defaults to None. - local_tags (Optional[List[str]], optional): The local tags. - Defaults to None. - inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable - metadata. Defaults to None. - local_metadata (Optional[Dict[str, Any]], optional): The local metadata. - Defaults to None. - - Returns: - AsyncCallbackManager: The configured async callback manager. 
- """ - return _configure( - cls, - inheritable_callbacks, - local_callbacks, - verbose, - inheritable_tags, - local_tags, - inheritable_metadata, - local_metadata, - ) - - -class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): - """Async callback manager for the chain group.""" - - def __init__( - self, - handlers: List[BaseCallbackHandler], - inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, - parent_run_id: Optional[UUID] = None, - *, - parent_run_manager: AsyncCallbackManagerForChainRun, - **kwargs: Any, - ) -> None: - super().__init__( - handlers, - inheritable_handlers, - parent_run_id, - **kwargs, - ) - self.parent_run_manager = parent_run_manager - self.ended = False - - def copy(self) -> AsyncCallbackManagerForChainGroup: - return self.__class__( - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, - parent_run_manager=self.parent_run_manager, - ) - - async def on_chain_end( - self, outputs: Union[Dict[str, Any], Any], **kwargs: Any - ) -> None: - """Run when traced chain group ends. - - Args: - outputs (Union[Dict[str, Any], Any]): The outputs of the chain. - """ - self.ended = True - await self.parent_run_manager.on_chain_end(outputs, **kwargs) - - async def on_chain_error( - self, - error: BaseException, - **kwargs: Any, - ) -> None: - """Run when chain errors. - - Args: - error (Exception or KeyboardInterrupt): The error. - """ - self.ended = True - await self.parent_run_manager.on_chain_error(error, **kwargs) - - -T = TypeVar("T", CallbackManager, AsyncCallbackManager) - - -def env_var_is_set(env_var: str) -> bool: - """Check if an environment variable is set. - - Args: - env_var (str): The name of the environment variable. - - Returns: - bool: True if the environment variable is set, False otherwise. - """ - return env_var in os.environ and os.environ[env_var] not in ( - "", - "0", - "false", - "False", - ) - - -def _tracing_v2_is_enabled() -> bool: - return ( - env_var_is_set("LANGCHAIN_TRACING_V2") - or tracing_v2_callback_var.get() is not None - or get_run_tree_context() is not None - ) - - -def _get_tracer_project() -> str: - run_tree = get_run_tree_context() - return getattr( - run_tree, - "session_name", - getattr( - # Note, if people are trying to nest @traceable functions and the - # tracing_v2_enabled context manager, this will likely mess up the - # tree structure. - tracing_v2_callback_var.get(), - "project", - # Have to set this to a string even though it always will return - # a string because `get_tracer_project` technically can return - # None, but only when a specific argument is supplied. - # Therefore, this just tricks the mypy type checker - str(ls_utils.get_tracer_project()), - ), - ) - - -_configure_hooks: List[ - Tuple[ - ContextVar[Optional[BaseCallbackHandler]], - bool, - Optional[Type[BaseCallbackHandler]], - Optional[str], - ] -] = [] - -H = TypeVar("H", bound=BaseCallbackHandler, covariant=True) - - -def register_configure_hook( - context_var: ContextVar[Optional[Any]], - inheritable: bool, - handle_class: Optional[Type[BaseCallbackHandler]] = None, - env_var: Optional[str] = None, -) -> None: - if env_var is not None and handle_class is None: - raise ValueError( - "If env_var is set, handle_class must also be set to a non-None value." 
- ) - _configure_hooks.append( - ( - # the typings of ContextVar do not have the generic arg set as covariant - # so we have to cast it - cast(ContextVar[Optional[BaseCallbackHandler]], context_var), - inheritable, - handle_class, - env_var, - ) - ) - - -register_configure_hook(run_collector_var, False) - - -def _configure( - callback_manager_cls: Type[T], - inheritable_callbacks: Callbacks = None, - local_callbacks: Callbacks = None, - verbose: bool = False, - inheritable_tags: Optional[List[str]] = None, - local_tags: Optional[List[str]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - local_metadata: Optional[Dict[str, Any]] = None, -) -> T: - """Configure the callback manager. - - Args: - callback_manager_cls (Type[T]): The callback manager class. - inheritable_callbacks (Optional[Callbacks], optional): The inheritable - callbacks. Defaults to None. - local_callbacks (Optional[Callbacks], optional): The local callbacks. - Defaults to None. - verbose (bool, optional): Whether to enable verbose mode. Defaults to False. - inheritable_tags (Optional[List[str]], optional): The inheritable tags. - Defaults to None. - local_tags (Optional[List[str]], optional): The local tags. Defaults to None. - inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable - metadata. Defaults to None. - local_metadata (Optional[Dict[str, Any]], optional): The local metadata. - Defaults to None. - - Returns: - T: The configured callback manager. - """ - run_tree = get_run_tree_context() - parent_run_id = None if run_tree is None else getattr(run_tree, "id") - callback_manager = callback_manager_cls(handlers=[], parent_run_id=parent_run_id) - if inheritable_callbacks or local_callbacks: - if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: - inheritable_callbacks_ = inheritable_callbacks or [] - callback_manager = callback_manager_cls( - handlers=inheritable_callbacks_.copy(), - inheritable_handlers=inheritable_callbacks_.copy(), - parent_run_id=parent_run_id, - ) - else: - callback_manager = callback_manager_cls( - handlers=inheritable_callbacks.handlers.copy(), - inheritable_handlers=inheritable_callbacks.inheritable_handlers.copy(), - parent_run_id=inheritable_callbacks.parent_run_id, - tags=inheritable_callbacks.tags.copy(), - inheritable_tags=inheritable_callbacks.inheritable_tags.copy(), - metadata=inheritable_callbacks.metadata.copy(), - inheritable_metadata=inheritable_callbacks.inheritable_metadata.copy(), - ) - local_handlers_ = ( - local_callbacks - if isinstance(local_callbacks, list) - else (local_callbacks.handlers if local_callbacks else []) - ) - for handler in local_handlers_: - callback_manager.add_handler(handler, False) - if inheritable_tags or local_tags: - callback_manager.add_tags(inheritable_tags or []) - callback_manager.add_tags(local_tags or [], False) - if inheritable_metadata or local_metadata: - callback_manager.add_metadata(inheritable_metadata or {}) - callback_manager.add_metadata(local_metadata or {}, False) - - tracer = tracing_callback_var.get() - tracing_enabled_ = ( - env_var_is_set("LANGCHAIN_TRACING") - or tracer is not None - or env_var_is_set("LANGCHAIN_HANDLER") - ) - - tracer_v2 = tracing_v2_callback_var.get() - tracing_v2_enabled_ = _tracing_v2_is_enabled() - tracer_project = _get_tracer_project() - debug = _get_debug() - if verbose or debug or tracing_enabled_ or tracing_v2_enabled_: - if verbose and not any( - isinstance(handler, StdOutCallbackHandler) - for handler in callback_manager.handlers - ): - if 
debug: - pass - else: - callback_manager.add_handler(StdOutCallbackHandler(), False) - if debug and not any( - isinstance(handler, ConsoleCallbackHandler) - for handler in callback_manager.handlers - ): - callback_manager.add_handler(ConsoleCallbackHandler(), True) - if tracing_enabled_ and not any( - isinstance(handler, LangChainTracerV1) - for handler in callback_manager.handlers - ): - if tracer: - callback_manager.add_handler(tracer, True) - else: - handler = LangChainTracerV1() - handler.load_session(tracer_project) - callback_manager.add_handler(handler, True) - if tracing_v2_enabled_ and not any( - isinstance(handler, LangChainTracer) - for handler in callback_manager.handlers - ): - if tracer_v2: - callback_manager.add_handler(tracer_v2, True) - else: - try: - handler = LangChainTracer(project_name=tracer_project) - callback_manager.add_handler(handler, True) - except Exception as e: - logger.warning( - "Unable to load requested LangChainTracer." - " To disable this warning," - " unset the LANGCHAIN_TRACING_V2 environment variables.", - e, - ) - for var, inheritable, handler_class, env_var in _configure_hooks: - create_one = ( - env_var is not None - and env_var_is_set(env_var) - and handler_class is not None - ) - if var.get() is not None or create_one: - var_handler = var.get() or cast(Type[BaseCallbackHandler], handler_class)() - if handler_class is None: - if not any( - handler is var_handler # direct pointer comparison - for handler in callback_manager.handlers - ): - callback_manager.add_handler(var_handler, inheritable) - else: - if not any( - isinstance(handler, handler_class) - for handler in callback_manager.handlers - ): - callback_manager.add_handler(var_handler, inheritable) - return callback_manager + AsyncRunManager, + BaseRunManager, + CallbackManager, + CallbackManagerForChainGroup, + CallbackManagerForChainRun, + CallbackManagerForLLMRun, + CallbackManagerForRetrieverRun, + CallbackManagerForToolRun, + ParentRunManager, + RunManager, + collect_runs, + env_var_is_set, + handle_event, + register_configure_hook, + trace_as_chain_group, + tracing_enabled, + tracing_v2_enabled, +) + +__all__ = [ + "tracing_enabled", + "tracing_v2_enabled", + "collect_runs", + "trace_as_chain_group", + "handle_event", + "BaseRunManager", + "RunManager", + "ParentRunManager", + "AsyncRunManager", + "AsyncParentRunManager", + "CallbackManagerForLLMRun", + "AsyncCallbackManagerForLLMRun", + "CallbackManagerForChainRun", + "AsyncCallbackManagerForChainRun", + "CallbackManagerForToolRun", + "AsyncCallbackManagerForToolRun", + "CallbackManagerForRetrieverRun", + "AsyncCallbackManagerForRetrieverRun", + "CallbackManager", + "CallbackManagerForChainGroup", + "AsyncCallbackManager", + "AsyncCallbackManagerForChainGroup", + "env_var_is_set", + "register_configure_hook", +] diff --git a/libs/langchain/langchain/schema/callbacks/stdout.py b/libs/langchain/langchain/schema/callbacks/stdout.py index 63e71a30dd5..754e58248e4 100644 --- a/libs/langchain/langchain/schema/callbacks/stdout.py +++ b/libs/langchain/langchain/schema/callbacks/stdout.py @@ -1,97 +1,3 @@ -"""Callback Handler that prints to std out.""" -from typing import Any, Dict, List, Optional +from langchain_core.callbacks.stdout import StdOutCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult -from langchain.schema.callbacks.base import BaseCallbackHandler -from langchain.utils.input import print_text - - -class StdOutCallbackHandler(BaseCallbackHandler): - """Callback Handler that prints to std out.""" - - 
def __init__(self, color: Optional[str] = None) -> None: - """Initialize callback handler.""" - self.color = color - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Print out the prompts.""" - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Print out that we are entering a chain.""" - class_name = serialized.get("name", serialized.get("id", [""])[-1]) - print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" - print("\n\033[1m> Finished chain.\033[0m") - - def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Do nothing.""" - pass - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - """Run on agent action.""" - print_text(action.log, color=color or self.color) - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """If not the final action, print out observation.""" - if observation_prefix is not None: - print_text(f"\n{observation_prefix}") - print_text(output, color=color or self.color) - if llm_prefix is not None: - print_text(f"\n{llm_prefix}") - - def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Any, - ) -> None: - """Run when agent ends.""" - print_text(text, color=color or self.color, end=end) - - def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> None: - """Run on agent end.""" - print_text(finish.log, color=color or self.color, end="\n") +__all__ = ["StdOutCallbackHandler"] diff --git a/libs/langchain/langchain/schema/callbacks/streaming_stdout.py b/libs/langchain/langchain/schema/callbacks/streaming_stdout.py index dd0896801a8..35608689634 100644 --- a/libs/langchain/langchain/schema/callbacks/streaming_stdout.py +++ b/libs/langchain/langchain/schema/callbacks/streaming_stdout.py @@ -1,67 +1,3 @@ -"""Callback Handler streams to stdout on new llm token.""" -import sys -from typing import Any, Dict, List +from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult -from langchain.schema.callbacks.base import BaseCallbackHandler -from langchain.schema.messages import BaseMessage - - -class StreamingStdOutCallbackHandler(BaseCallbackHandler): - """Callback handler for streaming. 
Only works with LLMs that support streaming.""" - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Run when LLM starts running.""" - - def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - **kwargs: Any, - ) -> None: - """Run when LLM starts running.""" - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. Only available when streaming is enabled.""" - sys.stdout.write(token) - sys.stdout.flush() - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" - - def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: - """Run when LLM errors.""" - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Run when chain starts running.""" - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Run when chain ends running.""" - - def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: - """Run when chain errors.""" - - def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> None: - """Run when tool starts running.""" - - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run on agent action.""" - pass - - def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running.""" - - def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: - """Run when tool errors.""" - - def on_text(self, text: str, **kwargs: Any) -> None: - """Run on arbitrary text.""" - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Run on agent end.""" +__all__ = ["StreamingStdOutCallbackHandler"] diff --git a/libs/langchain/langchain/schema/callbacks/tracers/base.py b/libs/langchain/langchain/schema/callbacks/tracers/base.py index 4f136a1911d..8f9e3d61578 100644 --- a/libs/langchain/langchain/schema/callbacks/tracers/base.py +++ b/libs/langchain/langchain/schema/callbacks/tracers/base.py @@ -1,537 +1,3 @@ -"""Base interfaces for tracing runs.""" -from __future__ import annotations +from langchain_core.callbacks.tracers.base import BaseTracer, TracerException -import logging -from abc import ABC, abstractmethod -from datetime import datetime -from typing import Any, Dict, List, Optional, Sequence, Union, cast -from uuid import UUID - -from tenacity import RetryCallState - -from langchain.load.dump import dumpd -from langchain.schema.callbacks.base import BaseCallbackHandler -from langchain.schema.callbacks.tracers.schemas import Run -from langchain.schema.document import Document -from langchain.schema.output import ( - ChatGeneration, - ChatGenerationChunk, - GenerationChunk, - LLMResult, -) - -logger = logging.getLogger(__name__) - - -class TracerException(Exception): - """Base class for exceptions in tracers module.""" - - -class BaseTracer(BaseCallbackHandler, ABC): - """Base interface for tracers.""" - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.run_map: Dict[str, Run] = {} - - @staticmethod - def _add_child_run( - parent_run: Run, - child_run: Run, - ) -> None: - """Add child run to a chain run or tool run.""" - parent_run.child_runs.append(child_run) - - @abstractmethod - def _persist_run(self, run: Run) -> None: - """Persist a run.""" - - def _start_trace(self, run: Run) -> None: - """Start a trace for a run.""" - if run.parent_run_id: - parent_run = 
self.run_map.get(str(run.parent_run_id)) - if parent_run: - self._add_child_run(parent_run, run) - parent_run.child_execution_order = max( - parent_run.child_execution_order, run.child_execution_order - ) - else: - logger.debug(f"Parent run with UUID {run.parent_run_id} not found.") - self.run_map[str(run.id)] = run - self._on_run_create(run) - - def _end_trace(self, run: Run) -> None: - """End a trace for a run.""" - if not run.parent_run_id: - self._persist_run(run) - else: - parent_run = self.run_map.get(str(run.parent_run_id)) - if parent_run is None: - logger.debug(f"Parent run with UUID {run.parent_run_id} not found.") - elif ( - run.child_execution_order is not None - and parent_run.child_execution_order is not None - and run.child_execution_order > parent_run.child_execution_order - ): - parent_run.child_execution_order = run.child_execution_order - self.run_map.pop(str(run.id)) - self._on_run_update(run) - - def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int: - """Get the execution order for a run.""" - if parent_run_id is None: - return 1 - - parent_run = self.run_map.get(parent_run_id) - if parent_run is None: - logger.debug(f"Parent run with UUID {parent_run_id} not found.") - return 1 - if parent_run.child_execution_order is None: - raise TracerException( - f"Parent run with UUID {parent_run_id} has no child execution order." - ) - - return parent_run.child_execution_order + 1 - - def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - *, - run_id: UUID, - tags: Optional[List[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, - name: Optional[str] = None, - **kwargs: Any, - ) -> Run: - """Start a trace for an LLM run.""" - parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - start_time = datetime.utcnow() - if metadata: - kwargs.update({"metadata": metadata}) - llm_run = Run( - id=run_id, - parent_run_id=parent_run_id, - serialized=serialized, - inputs={"prompts": prompts}, - extra=kwargs, - events=[{"name": "start", "time": start_time}], - start_time=start_time, - execution_order=execution_order, - child_execution_order=execution_order, - run_type="llm", - tags=tags or [], - name=name, - ) - self._start_trace(llm_run) - self._on_llm_start(llm_run) - return llm_run - - def on_llm_new_token( - self, - token: str, - *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Run: - """Run on new LLM token. 
Only available when streaming is enabled.""" - if not run_id: - raise TracerException("No run_id provided for on_llm_new_token callback.") - - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None or llm_run.run_type != "llm": - raise TracerException(f"No LLM Run found to be traced for {run_id}") - event_kwargs: Dict[str, Any] = {"token": token} - if chunk: - event_kwargs["chunk"] = chunk - llm_run.events.append( - { - "name": "new_token", - "time": datetime.utcnow(), - "kwargs": event_kwargs, - }, - ) - self._on_llm_new_token(llm_run, token, chunk) - return llm_run - - def on_retry( - self, - retry_state: RetryCallState, - *, - run_id: UUID, - **kwargs: Any, - ) -> Run: - if not run_id: - raise TracerException("No run_id provided for on_retry callback.") - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None: - raise TracerException("No Run found to be traced for on_retry") - retry_d: Dict[str, Any] = { - "slept": retry_state.idle_for, - "attempt": retry_state.attempt_number, - } - if retry_state.outcome is None: - retry_d["outcome"] = "N/A" - elif retry_state.outcome.failed: - retry_d["outcome"] = "failed" - exception = retry_state.outcome.exception() - retry_d["exception"] = str(exception) - retry_d["exception_type"] = exception.__class__.__name__ - else: - retry_d["outcome"] = "success" - retry_d["result"] = str(retry_state.outcome.result()) - llm_run.events.append( - { - "name": "retry", - "time": datetime.utcnow(), - "kwargs": retry_d, - }, - ) - return llm_run - - def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run: - """End a trace for an LLM run.""" - if not run_id: - raise TracerException("No run_id provided for on_llm_end callback.") - - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None or llm_run.run_type != "llm": - raise TracerException(f"No LLM Run found to be traced for {run_id}") - llm_run.outputs = response.dict() - for i, generations in enumerate(response.generations): - for j, generation in enumerate(generations): - output_generation = llm_run.outputs["generations"][i][j] - if "message" in output_generation: - output_generation["message"] = dumpd( - cast(ChatGeneration, generation).message - ) - llm_run.end_time = datetime.utcnow() - llm_run.events.append({"name": "end", "time": llm_run.end_time}) - self._end_trace(llm_run) - self._on_llm_end(llm_run) - return llm_run - - def on_llm_error( - self, - error: BaseException, - *, - run_id: UUID, - **kwargs: Any, - ) -> Run: - """Handle an error for an LLM run.""" - if not run_id: - raise TracerException("No run_id provided for on_llm_error callback.") - - run_id_ = str(run_id) - llm_run = self.run_map.get(run_id_) - if llm_run is None or llm_run.run_type != "llm": - raise TracerException(f"No LLM Run found to be traced for {run_id}") - llm_run.error = repr(error) - llm_run.end_time = datetime.utcnow() - llm_run.events.append({"name": "error", "time": llm_run.end_time}) - self._end_trace(llm_run) - self._on_chain_error(llm_run) - return llm_run - - def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], - *, - run_id: UUID, - tags: Optional[List[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, - run_type: Optional[str] = None, - name: Optional[str] = None, - **kwargs: Any, - ) -> Run: - """Start a trace for a chain run.""" - parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = 
self._get_execution_order(parent_run_id_) - start_time = datetime.utcnow() - if metadata: - kwargs.update({"metadata": metadata}) - chain_run = Run( - id=run_id, - parent_run_id=parent_run_id, - serialized=serialized, - inputs=inputs if isinstance(inputs, dict) else {"input": inputs}, - extra=kwargs, - events=[{"name": "start", "time": start_time}], - start_time=start_time, - execution_order=execution_order, - child_execution_order=execution_order, - child_runs=[], - run_type=run_type or "chain", - name=name, - tags=tags or [], - ) - self._start_trace(chain_run) - self._on_chain_start(chain_run) - return chain_run - - def on_chain_end( - self, - outputs: Dict[str, Any], - *, - run_id: UUID, - inputs: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Run: - """End a trace for a chain run.""" - if not run_id: - raise TracerException("No run_id provided for on_chain_end callback.") - chain_run = self.run_map.get(str(run_id)) - if chain_run is None: - raise TracerException(f"No chain Run found to be traced for {run_id}") - - chain_run.outputs = ( - outputs if isinstance(outputs, dict) else {"output": outputs} - ) - chain_run.end_time = datetime.utcnow() - chain_run.events.append({"name": "end", "time": chain_run.end_time}) - if inputs is not None: - chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs} - self._end_trace(chain_run) - self._on_chain_end(chain_run) - return chain_run - - def on_chain_error( - self, - error: BaseException, - *, - inputs: Optional[Dict[str, Any]] = None, - run_id: UUID, - **kwargs: Any, - ) -> Run: - """Handle an error for a chain run.""" - if not run_id: - raise TracerException("No run_id provided for on_chain_error callback.") - chain_run = self.run_map.get(str(run_id)) - if chain_run is None: - raise TracerException(f"No chain Run found to be traced for {run_id}") - - chain_run.error = repr(error) - chain_run.end_time = datetime.utcnow() - chain_run.events.append({"name": "error", "time": chain_run.end_time}) - if inputs is not None: - chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs} - self._end_trace(chain_run) - self._on_chain_error(chain_run) - return chain_run - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - *, - run_id: UUID, - tags: Optional[List[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, - name: Optional[str] = None, - **kwargs: Any, - ) -> Run: - """Start a trace for a tool run.""" - parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - start_time = datetime.utcnow() - if metadata: - kwargs.update({"metadata": metadata}) - tool_run = Run( - id=run_id, - parent_run_id=parent_run_id, - serialized=serialized, - inputs={"input": input_str}, - extra=kwargs, - events=[{"name": "start", "time": start_time}], - start_time=start_time, - execution_order=execution_order, - child_execution_order=execution_order, - child_runs=[], - run_type="tool", - tags=tags or [], - name=name, - ) - self._start_trace(tool_run) - self._on_tool_start(tool_run) - return tool_run - - def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run: - """End a trace for a tool run.""" - if not run_id: - raise TracerException("No run_id provided for on_tool_end callback.") - tool_run = self.run_map.get(str(run_id)) - if tool_run is None or tool_run.run_type != "tool": - raise TracerException(f"No tool Run found to be traced for {run_id}") - - tool_run.outputs = 
{"output": output} - tool_run.end_time = datetime.utcnow() - tool_run.events.append({"name": "end", "time": tool_run.end_time}) - self._end_trace(tool_run) - self._on_tool_end(tool_run) - return tool_run - - def on_tool_error( - self, - error: BaseException, - *, - run_id: UUID, - **kwargs: Any, - ) -> Run: - """Handle an error for a tool run.""" - if not run_id: - raise TracerException("No run_id provided for on_tool_error callback.") - tool_run = self.run_map.get(str(run_id)) - if tool_run is None or tool_run.run_type != "tool": - raise TracerException(f"No tool Run found to be traced for {run_id}") - - tool_run.error = repr(error) - tool_run.end_time = datetime.utcnow() - tool_run.events.append({"name": "error", "time": tool_run.end_time}) - self._end_trace(tool_run) - self._on_tool_error(tool_run) - return tool_run - - def on_retriever_start( - self, - serialized: Dict[str, Any], - query: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - name: Optional[str] = None, - **kwargs: Any, - ) -> Run: - """Run when Retriever starts running.""" - parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - start_time = datetime.utcnow() - if metadata: - kwargs.update({"metadata": metadata}) - retrieval_run = Run( - id=run_id, - name=name or "Retriever", - parent_run_id=parent_run_id, - serialized=serialized, - inputs={"query": query}, - extra=kwargs, - events=[{"name": "start", "time": start_time}], - start_time=start_time, - execution_order=execution_order, - child_execution_order=execution_order, - tags=tags, - child_runs=[], - run_type="retriever", - ) - self._start_trace(retrieval_run) - self._on_retriever_start(retrieval_run) - return retrieval_run - - def on_retriever_error( - self, - error: BaseException, - *, - run_id: UUID, - **kwargs: Any, - ) -> Run: - """Run when Retriever errors.""" - if not run_id: - raise TracerException("No run_id provided for on_retriever_error callback.") - retrieval_run = self.run_map.get(str(run_id)) - if retrieval_run is None or retrieval_run.run_type != "retriever": - raise TracerException(f"No retriever Run found to be traced for {run_id}") - - retrieval_run.error = repr(error) - retrieval_run.end_time = datetime.utcnow() - retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time}) - self._end_trace(retrieval_run) - self._on_retriever_error(retrieval_run) - return retrieval_run - - def on_retriever_end( - self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any - ) -> Run: - """Run when Retriever ends running.""" - if not run_id: - raise TracerException("No run_id provided for on_retriever_end callback.") - retrieval_run = self.run_map.get(str(run_id)) - if retrieval_run is None or retrieval_run.run_type != "retriever": - raise TracerException(f"No retriever Run found to be traced for {run_id}") - retrieval_run.outputs = {"documents": documents} - retrieval_run.end_time = datetime.utcnow() - retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time}) - self._end_trace(retrieval_run) - self._on_retriever_end(retrieval_run) - return retrieval_run - - def __deepcopy__(self, memo: dict) -> BaseTracer: - """Deepcopy the tracer.""" - return self - - def __copy__(self) -> BaseTracer: - """Copy the tracer.""" - return self - - def _on_run_create(self, run: Run) -> None: - """Process a run upon creation.""" - - def _on_run_update(self, run: Run) -> None: 
- """Process a run upon update.""" - - def _on_llm_start(self, run: Run) -> None: - """Process the LLM Run upon start.""" - - def _on_llm_new_token( - self, - run: Run, - token: str, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], - ) -> None: - """Process new LLM token.""" - - def _on_llm_end(self, run: Run) -> None: - """Process the LLM Run.""" - - def _on_llm_error(self, run: Run) -> None: - """Process the LLM Run upon error.""" - - def _on_chain_start(self, run: Run) -> None: - """Process the Chain Run upon start.""" - - def _on_chain_end(self, run: Run) -> None: - """Process the Chain Run.""" - - def _on_chain_error(self, run: Run) -> None: - """Process the Chain Run upon error.""" - - def _on_tool_start(self, run: Run) -> None: - """Process the Tool Run upon start.""" - - def _on_tool_end(self, run: Run) -> None: - """Process the Tool Run.""" - - def _on_tool_error(self, run: Run) -> None: - """Process the Tool Run upon error.""" - - def _on_chat_model_start(self, run: Run) -> None: - """Process the Chat Model Run upon start.""" - - def _on_retriever_start(self, run: Run) -> None: - """Process the Retriever Run upon start.""" - - def _on_retriever_end(self, run: Run) -> None: - """Process the Retriever Run.""" - - def _on_retriever_error(self, run: Run) -> None: - """Process the Retriever Run upon error.""" +__all__ = ["TracerException", "BaseTracer"] diff --git a/libs/langchain/langchain/schema/callbacks/tracers/evaluation.py b/libs/langchain/langchain/schema/callbacks/tracers/evaluation.py index eac08e6c0e2..c847c53d285 100644 --- a/libs/langchain/langchain/schema/callbacks/tracers/evaluation.py +++ b/libs/langchain/langchain/schema/callbacks/tracers/evaluation.py @@ -1,222 +1,6 @@ -"""A tracer that runs evaluators over completed runs.""" -from __future__ import annotations +from langchain_core.callbacks.tracers.evaluation import ( + EvaluatorCallbackHandler, + wait_for_all_evaluators, +) -import logging -import threading -import weakref -from concurrent.futures import Future, ThreadPoolExecutor, wait -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast -from uuid import UUID - -import langsmith -from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults - -from langchain.schema.callbacks import manager -from langchain.schema.callbacks.tracers import langchain as langchain_tracer -from langchain.schema.callbacks.tracers.base import BaseTracer -from langchain.schema.callbacks.tracers.langchain import _get_executor -from langchain.schema.callbacks.tracers.schemas import Run - -logger = logging.getLogger(__name__) - -_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet() - - -def wait_for_all_evaluators() -> None: - """Wait for all tracers to finish.""" - global _TRACERS - for tracer in list(_TRACERS): - if tracer is not None: - tracer.wait_for_futures() - - -class EvaluatorCallbackHandler(BaseTracer): - """A tracer that runs a run evaluator whenever a run is persisted. - - Parameters - ---------- - evaluators : Sequence[RunEvaluator] - The run evaluators to apply to all top level runs. - client : LangSmith Client, optional - The LangSmith client instance to use for evaluating the runs. - If not specified, a new instance will be created. - example_id : Union[UUID, str], optional - The example ID to be associated with the runs. - project_name : str, optional - The LangSmith project name to be organize eval chain runs under. 
- - Attributes - ---------- - example_id : Union[UUID, None] - The example ID associated with the runs. - client : Client - The LangSmith client instance used for evaluating the runs. - evaluators : Sequence[RunEvaluator] - The sequence of run evaluators to be executed. - executor : ThreadPoolExecutor - The thread pool executor used for running the evaluators. - futures : Set[Future] - The set of futures representing the running evaluators. - skip_unfinished : bool - Whether to skip runs that are not finished or raised - an error. - project_name : Optional[str] - The LangSmith project name to be organize eval chain runs under. - """ - - name = "evaluator_callback_handler" - - def __init__( - self, - evaluators: Sequence[langsmith.RunEvaluator], - client: Optional[langsmith.Client] = None, - example_id: Optional[Union[UUID, str]] = None, - skip_unfinished: bool = True, - project_name: Optional[str] = "evaluators", - max_concurrency: Optional[int] = None, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - self.example_id = ( - UUID(example_id) if isinstance(example_id, str) else example_id - ) - self.client = client or langchain_tracer.get_client() - self.evaluators = evaluators - if max_concurrency is None: - self.executor: Optional[ThreadPoolExecutor] = _get_executor() - elif max_concurrency > 0: - self.executor = ThreadPoolExecutor(max_workers=max_concurrency) - weakref.finalize( - self, - lambda: cast(ThreadPoolExecutor, self.executor).shutdown(wait=True), - ) - else: - self.executor = None - self.futures: weakref.WeakSet[Future] = weakref.WeakSet() - self.skip_unfinished = skip_unfinished - self.project_name = project_name - self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {} - self.lock = threading.Lock() - global _TRACERS - _TRACERS.add(self) - - def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None: - """Evaluate the run in the project. - - Parameters - ---------- - run : Run - The run to be evaluated. - evaluator : RunEvaluator - The evaluator to use for evaluating the run. - - """ - try: - if self.project_name is None: - eval_result = self.client.evaluate_run(run, evaluator) - eval_results = [eval_result] - with manager.tracing_v2_enabled( - project_name=self.project_name, tags=["eval"], client=self.client - ) as cb: - reference_example = ( - self.client.read_example(run.reference_example_id) - if run.reference_example_id - else None - ) - evaluation_result = evaluator.evaluate_run( - run, - example=reference_example, - ) - eval_results = self._log_evaluation_feedback( - evaluation_result, - run, - source_run_id=cb.latest_run.id if cb.latest_run else None, - ) - except Exception as e: - logger.error( - f"Error evaluating run {run.id} with " - f"{evaluator.__class__.__name__}: {repr(e)}", - exc_info=True, - ) - raise e - example_id = str(run.reference_example_id) - with self.lock: - for res in eval_results: - run_id = ( - str(getattr(res, "target_run_id")) - if hasattr(res, "target_run_id") - else str(run.id) - ) - self.logged_eval_results.setdefault((run_id, example_id), []).append( - res - ) - - def _select_eval_results( - self, - results: Union[EvaluationResult, EvaluationResults], - ) -> List[EvaluationResult]: - if isinstance(results, EvaluationResult): - results_ = [results] - elif isinstance(results, dict) and "results" in results: - results_ = cast(List[EvaluationResult], results["results"]) - else: - raise TypeError( - f"Invalid evaluation result type {type(results)}." 
- " Expected EvaluationResult or EvaluationResults." - ) - return results_ - - def _log_evaluation_feedback( - self, - evaluator_response: Union[EvaluationResult, EvaluationResults], - run: Run, - source_run_id: Optional[UUID] = None, - ) -> List[EvaluationResult]: - results = self._select_eval_results(evaluator_response) - for res in results: - source_info_: Dict[str, Any] = {} - if res.evaluator_info: - source_info_ = {**res.evaluator_info, **source_info_} - run_id_ = ( - getattr(res, "target_run_id") - if hasattr(res, "target_run_id") and res.target_run_id is not None - else run.id - ) - self.client.create_feedback( - run_id_, - res.key, - score=res.score, - value=res.value, - comment=res.comment, - correction=res.correction, - source_info=source_info_, - source_run_id=res.source_run_id or source_run_id, - feedback_source_type=langsmith.schemas.FeedbackSourceType.MODEL, - ) - return results - - def _persist_run(self, run: Run) -> None: - """Run the evaluator on the run. - - Parameters - ---------- - run : Run - The run to be evaluated. - - """ - if self.skip_unfinished and not run.outputs: - logger.debug(f"Skipping unfinished run {run.id}") - return - run_ = run.copy() - run_.reference_example_id = self.example_id - for evaluator in self.evaluators: - if self.executor is None: - self._evaluate_in_project(run_, evaluator) - else: - self.futures.add( - self.executor.submit(self._evaluate_in_project, run_, evaluator) - ) - - def wait_for_futures(self) -> None: - """Wait for all futures to complete.""" - wait(self.futures) +__all__ = ["wait_for_all_evaluators", "EvaluatorCallbackHandler"] diff --git a/libs/langchain/langchain/schema/callbacks/tracers/langchain.py b/libs/langchain/langchain/schema/callbacks/tracers/langchain.py index c109c8b2a9d..284ae533f9a 100644 --- a/libs/langchain/langchain/schema/callbacks/tracers/langchain.py +++ b/libs/langchain/langchain/schema/callbacks/tracers/langchain.py @@ -1,262 +1,8 @@ -"""A Tracer implementation that records to LangChain endpoint.""" -from __future__ import annotations - -import logging -import weakref -from concurrent.futures import Future, ThreadPoolExecutor, wait -from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Union -from uuid import UUID - -from langsmith import Client -from langsmith import utils as ls_utils -from tenacity import ( - Retrying, - retry_if_exception_type, - stop_after_attempt, - wait_exponential_jitter, +from langchain_core.callbacks.tracers.langchain import ( + LangChainTracer, + get_client, + log_error_once, + wait_for_all_tracers, ) -from langchain.env import get_runtime_environment -from langchain.load.dump import dumpd -from langchain.schema.callbacks.tracers.base import BaseTracer -from langchain.schema.callbacks.tracers.schemas import Run -from langchain.schema.messages import BaseMessage - -logger = logging.getLogger(__name__) -_LOGGED = set() -_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet() -_CLIENT: Optional[Client] = None -_EXECUTOR: Optional[ThreadPoolExecutor] = None - - -def log_error_once(method: str, exception: Exception) -> None: - """Log an error once.""" - global _LOGGED - if (method, type(exception)) in _LOGGED: - return - _LOGGED.add((method, type(exception))) - logger.error(exception) - - -def wait_for_all_tracers() -> None: - """Wait for all tracers to finish.""" - global _TRACERS - for tracer in list(_TRACERS): - if tracer is not None: - tracer.wait_for_futures() - - -def get_client() -> Client: - """Get the client.""" - global _CLIENT - if 
_CLIENT is None: - _CLIENT = Client() - return _CLIENT - - -def _get_executor() -> ThreadPoolExecutor: - """Get the executor.""" - global _EXECUTOR - if _EXECUTOR is None: - _EXECUTOR = ThreadPoolExecutor() - return _EXECUTOR - - -def _copy(run: Run) -> Run: - """Copy a run.""" - try: - return run.copy(deep=True) - except TypeError: - # Fallback in case the object contains a lock or other - # non-pickleable object - return run.copy() - - -class LangChainTracer(BaseTracer): - """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - - def __init__( - self, - example_id: Optional[Union[UUID, str]] = None, - project_name: Optional[str] = None, - client: Optional[Client] = None, - tags: Optional[List[str]] = None, - use_threading: bool = True, - **kwargs: Any, - ) -> None: - """Initialize the LangChain tracer.""" - super().__init__(**kwargs) - self.example_id = ( - UUID(example_id) if isinstance(example_id, str) else example_id - ) - self.project_name = project_name or ls_utils.get_tracer_project() - self.client = client or get_client() - self._futures: weakref.WeakSet[Future] = weakref.WeakSet() - self.tags = tags or [] - self.executor = _get_executor() if use_threading else None - self.latest_run: Optional[Run] = None - global _TRACERS - _TRACERS.add(self) - - def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], - *, - run_id: UUID, - tags: Optional[List[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, - name: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Start a trace for an LLM run.""" - parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id_) - start_time = datetime.utcnow() - if metadata: - kwargs.update({"metadata": metadata}) - chat_model_run = Run( - id=run_id, - parent_run_id=parent_run_id, - serialized=serialized, - inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]}, - extra=kwargs, - events=[{"name": "start", "time": start_time}], - start_time=start_time, - execution_order=execution_order, - child_execution_order=execution_order, - run_type="llm", - tags=tags, - name=name, - ) - self._start_trace(chat_model_run) - self._on_chat_model_start(chat_model_run) - - def _persist_run(self, run: Run) -> None: - run_ = run.copy() - run_.reference_example_id = self.example_id - self.latest_run = run_ - - def get_run_url(self) -> str: - """Get the LangSmith root run URL""" - if not self.latest_run: - raise ValueError("No traced run found.") - # If this is the first run in a project, the project may not yet be created. - # This method is only really useful for debugging flows, so we will assume - # there is some tolerace for latency. 
- for attempt in Retrying( - stop=stop_after_attempt(5), - wait=wait_exponential_jitter(), - retry=retry_if_exception_type(ls_utils.LangSmithError), - ): - with attempt: - return self.client.get_run_url( - run=self.latest_run, project_name=self.project_name - ) - raise ValueError("Failed to get run URL.") - - def _get_tags(self, run: Run) -> List[str]: - """Get combined tags for a run.""" - tags = set(run.tags or []) - tags.update(self.tags or []) - return list(tags) - - def _persist_run_single(self, run: Run) -> None: - """Persist a run.""" - run_dict = run.dict(exclude={"child_runs"}) - run_dict["tags"] = self._get_tags(run) - extra = run_dict.get("extra", {}) - extra["runtime"] = get_runtime_environment() - run_dict["extra"] = extra - try: - self.client.create_run(**run_dict, project_name=self.project_name) - except Exception as e: - # Errors are swallowed by the thread executor so we need to log them here - log_error_once("post", e) - raise - - def _update_run_single(self, run: Run) -> None: - """Update a run.""" - try: - run_dict = run.dict() - run_dict["tags"] = self._get_tags(run) - self.client.update_run(run.id, **run_dict) - except Exception as e: - # Errors are swallowed by the thread executor so we need to log them here - log_error_once("patch", e) - raise - - def _submit(self, function: Callable[[Run], None], run: Run) -> None: - """Submit a function to the executor.""" - if self.executor is None: - function(run) - else: - self._futures.add(self.executor.submit(function, run)) - - def _on_llm_start(self, run: Run) -> None: - """Persist an LLM run.""" - if run.parent_run_id is None: - run.reference_example_id = self.example_id - self._submit(self._persist_run_single, _copy(run)) - - def _on_chat_model_start(self, run: Run) -> None: - """Persist an LLM run.""" - if run.parent_run_id is None: - run.reference_example_id = self.example_id - self._submit(self._persist_run_single, _copy(run)) - - def _on_llm_end(self, run: Run) -> None: - """Process the LLM Run.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_llm_error(self, run: Run) -> None: - """Process the LLM Run upon error.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_chain_start(self, run: Run) -> None: - """Process the Chain Run upon start.""" - if run.parent_run_id is None: - run.reference_example_id = self.example_id - self._submit(self._persist_run_single, _copy(run)) - - def _on_chain_end(self, run: Run) -> None: - """Process the Chain Run.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_chain_error(self, run: Run) -> None: - """Process the Chain Run upon error.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_tool_start(self, run: Run) -> None: - """Process the Tool Run upon start.""" - if run.parent_run_id is None: - run.reference_example_id = self.example_id - self._submit(self._persist_run_single, _copy(run)) - - def _on_tool_end(self, run: Run) -> None: - """Process the Tool Run.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_tool_error(self, run: Run) -> None: - """Process the Tool Run upon error.""" - self._submit(self._update_run_single, _copy(run)) - - def _on_retriever_start(self, run: Run) -> None: - """Process the Retriever Run upon start.""" - if run.parent_run_id is None: - run.reference_example_id = self.example_id - self._submit(self._persist_run_single, _copy(run)) - - def _on_retriever_end(self, run: Run) -> None: - """Process the Retriever Run.""" - self._submit(self._update_run_single, _copy(run)) - - def 
_on_retriever_error(self, run: Run) -> None: - """Process the Retriever Run upon error.""" - self._submit(self._update_run_single, _copy(run)) - - def wait_for_futures(self) -> None: - """Wait for the given futures to complete.""" - wait(self._futures) +__all__ = ["log_error_once", "wait_for_all_tracers", "get_client", "LangChainTracer"] diff --git a/libs/langchain/langchain/schema/callbacks/tracers/langchain_v1.py b/libs/langchain/langchain/schema/callbacks/tracers/langchain_v1.py index 957ae85875d..96154af452b 100644 --- a/libs/langchain/langchain/schema/callbacks/tracers/langchain_v1.py +++ b/libs/langchain/langchain/schema/callbacks/tracers/langchain_v1.py @@ -1,185 +1,3 @@ -from __future__ import annotations +from langchain_core.callbacks.tracers.langchain_v1 import LangChainTracerV1, get_headers -import logging -import os -from typing import Any, Dict, Optional, Union - -import requests - -from langchain.schema.callbacks.tracers.base import BaseTracer -from langchain.schema.callbacks.tracers.schemas import ( - ChainRun, - LLMRun, - Run, - ToolRun, - TracerSession, - TracerSessionV1, - TracerSessionV1Base, -) -from langchain.schema.messages import get_buffer_string -from langchain.utils import raise_for_status_with_text - -logger = logging.getLogger(__name__) - - -def get_headers() -> Dict[str, Any]: - """Get the headers for the LangChain API.""" - headers: Dict[str, Any] = {"Content-Type": "application/json"} - if os.getenv("LANGCHAIN_API_KEY"): - headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") - return headers - - -def _get_endpoint() -> str: - return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") - - -class LangChainTracerV1(BaseTracer): - """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - - def __init__(self, **kwargs: Any) -> None: - """Initialize the LangChain tracer.""" - super().__init__(**kwargs) - self.session: Optional[TracerSessionV1] = None - self._endpoint = _get_endpoint() - self._headers = get_headers() - - def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]: - session = self.session or self.load_default_session() - if not isinstance(session, TracerSessionV1): - raise ValueError( - "LangChainTracerV1 is not compatible with" - f" session of type {type(session)}" - ) - - if run.run_type == "llm": - if "prompts" in run.inputs: - prompts = run.inputs["prompts"] - elif "messages" in run.inputs: - prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]] - else: - raise ValueError("No prompts found in LLM run inputs") - return LLMRun( - uuid=str(run.id) if run.id else None, - parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, - start_time=run.start_time, - end_time=run.end_time, - extra=run.extra, - execution_order=run.execution_order, - child_execution_order=run.child_execution_order, - serialized=run.serialized, - session_id=session.id, - error=run.error, - prompts=prompts, - response=run.outputs if run.outputs else None, - ) - if run.run_type == "chain": - child_runs = [self._convert_to_v1_run(run) for run in run.child_runs] - return ChainRun( - uuid=str(run.id) if run.id else None, - parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, - start_time=run.start_time, - end_time=run.end_time, - execution_order=run.execution_order, - child_execution_order=run.child_execution_order, - serialized=run.serialized, - session_id=session.id, - inputs=run.inputs, - outputs=run.outputs, - error=run.error, - extra=run.extra, - child_llm_runs=[run for run in child_runs if 
isinstance(run, LLMRun)], - child_chain_runs=[ - run for run in child_runs if isinstance(run, ChainRun) - ], - child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], - ) - if run.run_type == "tool": - child_runs = [self._convert_to_v1_run(run) for run in run.child_runs] - return ToolRun( - uuid=str(run.id) if run.id else None, - parent_uuid=str(run.parent_run_id) if run.parent_run_id else None, - start_time=run.start_time, - end_time=run.end_time, - execution_order=run.execution_order, - child_execution_order=run.child_execution_order, - serialized=run.serialized, - session_id=session.id, - action=str(run.serialized), - tool_input=run.inputs.get("input", ""), - output=None if run.outputs is None else run.outputs.get("output"), - error=run.error, - extra=run.extra, - child_chain_runs=[ - run for run in child_runs if isinstance(run, ChainRun) - ], - child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)], - child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)], - ) - raise ValueError(f"Unknown run type: {run.run_type}") - - def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None: - """Persist a run.""" - if isinstance(run, Run): - v1_run = self._convert_to_v1_run(run) - else: - v1_run = run - if isinstance(v1_run, LLMRun): - endpoint = f"{self._endpoint}/llm-runs" - elif isinstance(v1_run, ChainRun): - endpoint = f"{self._endpoint}/chain-runs" - else: - endpoint = f"{self._endpoint}/tool-runs" - - try: - response = requests.post( - endpoint, - data=v1_run.json(), - headers=self._headers, - ) - raise_for_status_with_text(response) - except Exception as e: - logger.warning(f"Failed to persist run: {e}") - - def _persist_session( - self, session_create: TracerSessionV1Base - ) -> Union[TracerSessionV1, TracerSession]: - """Persist a session.""" - try: - r = requests.post( - f"{self._endpoint}/sessions", - data=session_create.json(), - headers=self._headers, - ) - session = TracerSessionV1(id=r.json()["id"], **session_create.dict()) - except Exception as e: - logger.warning(f"Failed to create session, using default session: {e}") - session = TracerSessionV1(id=1, **session_create.dict()) - return session - - def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1: - """Load a session from the tracer.""" - try: - url = f"{self._endpoint}/sessions" - if session_name: - url += f"?name={session_name}" - r = requests.get(url, headers=self._headers) - - tracer_session = TracerSessionV1(**r.json()[0]) - except Exception as e: - session_type = "default" if not session_name else session_name - logger.warning( - f"Failed to load {session_type} session, using empty session: {e}" - ) - tracer_session = TracerSessionV1(id=1) - - self.session = tracer_session - return tracer_session - - def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]: - """Load a session with the given name from the tracer.""" - return self._load_session(session_name) - - def load_default_session(self) -> Union[TracerSessionV1, TracerSession]: - """Load the default tracing session and set it as the Tracer's session.""" - return self._load_session("default") +__all__ = ["get_headers", "LangChainTracerV1"] diff --git a/libs/langchain/langchain/schema/callbacks/tracers/log_stream.py b/libs/langchain/langchain/schema/callbacks/tracers/log_stream.py index 6b2acd3cbd2..e7e29ba69cc 100644 --- a/libs/langchain/langchain/schema/callbacks/tracers/log_stream.py +++ 
b/libs/langchain/langchain/schema/callbacks/tracers/log_stream.py @@ -1,311 +1,9 @@ -from __future__ import annotations - -import math -import threading -from collections import defaultdict -from typing import ( - Any, - AsyncIterator, - Dict, - List, - Optional, - Sequence, - TypedDict, - Union, +from langchain_core.callbacks.tracers.log_stream import ( + LogEntry, + LogStreamCallbackHandler, + RunLog, + RunLogPatch, + RunState, ) -from uuid import UUID -import jsonpatch -from anyio import create_memory_object_stream - -from langchain.load.load import load -from langchain.schema.callbacks.tracers.base import BaseTracer -from langchain.schema.callbacks.tracers.schemas import Run -from langchain.schema.output import ChatGenerationChunk, GenerationChunk - - -class LogEntry(TypedDict): - """A single entry in the run log.""" - - id: str - """ID of the sub-run.""" - name: str - """Name of the object being run.""" - type: str - """Type of the object being run, eg. prompt, chain, llm, etc.""" - tags: List[str] - """List of tags for the run.""" - metadata: Dict[str, Any] - """Key-value pairs of metadata for the run.""" - start_time: str - """ISO-8601 timestamp of when the run started.""" - - streamed_output_str: List[str] - """List of LLM tokens streamed by this run, if applicable.""" - final_output: Optional[Any] - """Final output of this run. - Only available after the run has finished successfully.""" - end_time: Optional[str] - """ISO-8601 timestamp of when the run ended. - Only available after the run has finished.""" - - -class RunState(TypedDict): - """State of the run.""" - - id: str - """ID of the run.""" - streamed_output: List[Any] - """List of output chunks streamed by Runnable.stream()""" - final_output: Optional[Any] - """Final output of the run, usually the result of aggregating (`+`) streamed_output. - Only available after the run has finished successfully.""" - - logs: Dict[str, LogEntry] - """Map of run names to sub-runs. If filters were supplied, this list will - contain only the runs that matched the filters.""" - - -class RunLogPatch: - """A patch to the run log.""" - - ops: List[Dict[str, Any]] - """List of jsonpatch operations, which describe how to create the run state - from an empty dict. This is the minimal representation of the log, designed to - be serialized as JSON and sent over the wire to reconstruct the log on the other - side. 
Reconstruction of the state can be done with any jsonpatch-compliant library, - see https://jsonpatch.com for more information.""" - - def __init__(self, *ops: Dict[str, Any]) -> None: - self.ops = list(ops) - - def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: - if type(other) == RunLogPatch: - ops = self.ops + other.ops - state = jsonpatch.apply_patch(None, ops) - return RunLog(*ops, state=state) - - raise TypeError( - f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" - ) - - def __repr__(self) -> str: - from pprint import pformat - - # 1:-1 to get rid of the [] around the list - return f"RunLogPatch({pformat(self.ops)[1:-1]})" - - def __eq__(self, other: object) -> bool: - return isinstance(other, RunLogPatch) and self.ops == other.ops - - -class RunLog(RunLogPatch): - """A run log.""" - - state: RunState - """Current state of the log, obtained from applying all ops in sequence.""" - - def __init__(self, *ops: Dict[str, Any], state: RunState) -> None: - super().__init__(*ops) - self.state = state - - def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: - if type(other) == RunLogPatch: - ops = self.ops + other.ops - state = jsonpatch.apply_patch(self.state, other.ops) - return RunLog(*ops, state=state) - - raise TypeError( - f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" - ) - - def __repr__(self) -> str: - from pprint import pformat - - return f"RunLog({pformat(self.state)})" - - -class LogStreamCallbackHandler(BaseTracer): - """A tracer that streams run logs to a stream.""" - - def __init__( - self, - *, - auto_close: bool = True, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, - ) -> None: - super().__init__() - - self.auto_close = auto_close - self.include_names = include_names - self.include_types = include_types - self.include_tags = include_tags - self.exclude_names = exclude_names - self.exclude_types = exclude_types - self.exclude_tags = exclude_tags - - send_stream, receive_stream = create_memory_object_stream( - math.inf, item_type=RunLogPatch - ) - self.lock = threading.Lock() - self.send_stream = send_stream - self.receive_stream = receive_stream - self._key_map_by_run_id: Dict[UUID, str] = {} - self._counter_map_by_name: Dict[str, int] = defaultdict(int) - self.root_id: Optional[UUID] = None - - def __aiter__(self) -> AsyncIterator[RunLogPatch]: - return self.receive_stream.__aiter__() - - def include_run(self, run: Run) -> bool: - if run.id == self.root_id: - return False - - run_tags = run.tags or [] - - if ( - self.include_names is None - and self.include_types is None - and self.include_tags is None - ): - include = True - else: - include = False - - if self.include_names is not None: - include = include or run.name in self.include_names - if self.include_types is not None: - include = include or run.run_type in self.include_types - if self.include_tags is not None: - include = include or any(tag in self.include_tags for tag in run_tags) - - if self.exclude_names is not None: - include = include and run.name not in self.exclude_names - if self.exclude_types is not None: - include = include and run.run_type not in self.exclude_types - if self.exclude_tags is not None: - include = include and all(tag not in self.exclude_tags for tag in run_tags) - - return include - 
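# --- Editorial aside (illustrative sketch, not part of this patch): the RunLogPatch /
# RunLog classes removed above fold jsonpatch operations into a run state, as their
# __add__ methods show. The op values below are invented for illustration; only the
# `jsonpatch` package is assumed.
import jsonpatch

ops = [
    # The first patch replaces the empty root document with the initial RunState.
    {
        "op": "replace",
        "path": "",
        "value": {"id": "run-1", "streamed_output": [], "final_output": None, "logs": {}},
    },
    # A later patch appends one streamed chunk.
    {"op": "add", "path": "/streamed_output/-", "value": "Hello"},
]

state = jsonpatch.apply_patch(None, ops)  # rebuild the state from scratch, as RunLogPatch.__add__ does
assert state["streamed_output"] == ["Hello"]
# --- end aside ---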
- def _persist_run(self, run: Run) -> None: - # This is a legacy method only called once for an entire run tree - # therefore not useful here - pass - - def _on_run_create(self, run: Run) -> None: - """Start a run.""" - if self.root_id is None: - self.root_id = run.id - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "replace", - "path": "", - "value": RunState( - id=str(run.id), - streamed_output=[], - final_output=None, - logs={}, - ), - } - ) - ) - - if not self.include_run(run): - return - - # Determine previous index, increment by 1 - with self.lock: - self._counter_map_by_name[run.name] += 1 - count = self._counter_map_by_name[run.name] - self._key_map_by_run_id[run.id] = ( - run.name if count == 1 else f"{run.name}:{count}" - ) - - # Add the run to the stream - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "add", - "path": f"/logs/{self._key_map_by_run_id[run.id]}", - "value": LogEntry( - id=str(run.id), - name=run.name, - type=run.run_type, - tags=run.tags or [], - metadata=(run.extra or {}).get("metadata", {}), - start_time=run.start_time.isoformat(timespec="milliseconds"), - streamed_output_str=[], - final_output=None, - end_time=None, - ), - } - ) - ) - - def _on_run_update(self, run: Run) -> None: - """Finish a run.""" - try: - index = self._key_map_by_run_id.get(run.id) - - if index is None: - return - - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "add", - "path": f"/logs/{index}/final_output", - # to undo the dumpd done by some runnables / tracer / etc - "value": load(run.outputs), - }, - { - "op": "add", - "path": f"/logs/{index}/end_time", - "value": run.end_time.isoformat(timespec="milliseconds") - if run.end_time is not None - else None, - }, - ) - ) - finally: - if run.id == self.root_id: - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "replace", - "path": "/final_output", - "value": load(run.outputs), - } - ) - ) - if self.auto_close: - self.send_stream.close() - - def _on_llm_new_token( - self, - run: Run, - token: str, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], - ) -> None: - """Process new LLM token.""" - index = self._key_map_by_run_id.get(run.id) - - if index is None: - return - - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "add", - "path": f"/logs/{index}/streamed_output_str/-", - "value": token, - } - ) - ) +__all__ = ["LogEntry", "RunState", "RunLogPatch", "RunLog", "LogStreamCallbackHandler"] diff --git a/libs/langchain/langchain/schema/callbacks/tracers/root_listeners.py b/libs/langchain/langchain/schema/callbacks/tracers/root_listeners.py index 5a134e5946c..f57b31c938d 100644 --- a/libs/langchain/langchain/schema/callbacks/tracers/root_listeners.py +++ b/libs/langchain/langchain/schema/callbacks/tracers/root_listeners.py @@ -1,54 +1,3 @@ -from typing import Callable, Optional, Union -from uuid import UUID +from langchain_core.callbacks.tracers.root_listeners import RootListenersTracer -from langchain.schema.callbacks.tracers.base import BaseTracer -from langchain.schema.callbacks.tracers.schemas import Run -from langchain.schema.runnable.config import ( - RunnableConfig, - call_func_with_variable_args, -) - -Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - - -class RootListenersTracer(BaseTracer): - def __init__( - self, - *, - config: RunnableConfig, - on_start: Optional[Listener], - on_end: Optional[Listener], - on_error: Optional[Listener], - ) -> None: - super().__init__() - - self.config = config - self._arg_on_start = on_start - self._arg_on_end 
= on_end - self._arg_on_error = on_error - self.root_id: Optional[UUID] = None - - def _persist_run(self, run: Run) -> None: - # This is a legacy method only called once for an entire run tree - # therefore not useful here - pass - - def _on_run_create(self, run: Run) -> None: - if self.root_id is not None: - return - - self.root_id = run.id - - if self._arg_on_start is not None: - call_func_with_variable_args(self._arg_on_start, run, self.config) - - def _on_run_update(self, run: Run) -> None: - if run.id != self.root_id: - return - - if run.error is None: - if self._arg_on_end is not None: - call_func_with_variable_args(self._arg_on_end, run, self.config) - else: - if self._arg_on_error is not None: - call_func_with_variable_args(self._arg_on_error, run, self.config) +__all__ = ["RootListenersTracer"] diff --git a/libs/langchain/langchain/schema/callbacks/tracers/run_collector.py b/libs/langchain/langchain/schema/callbacks/tracers/run_collector.py index 8087121a13d..1e872946631 100644 --- a/libs/langchain/langchain/schema/callbacks/tracers/run_collector.py +++ b/libs/langchain/langchain/schema/callbacks/tracers/run_collector.py @@ -1,52 +1,3 @@ -"""A tracer that collects all nested runs in a list.""" +from langchain_core.callbacks.tracers.run_collector import RunCollectorCallbackHandler -from typing import Any, List, Optional, Union -from uuid import UUID - -from langchain.schema.callbacks.tracers.base import BaseTracer -from langchain.schema.callbacks.tracers.schemas import Run - - -class RunCollectorCallbackHandler(BaseTracer): - """ - A tracer that collects all nested runs in a list. - - This tracer is useful for inspection and evaluation purposes. - - Parameters - ---------- - example_id : Optional[Union[UUID, str]], default=None - The ID of the example being traced. It can be either a UUID or a string. - """ - - name: str = "run-collector_callback_handler" - - def __init__( - self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any - ) -> None: - """ - Initialize the RunCollectorCallbackHandler. - - Parameters - ---------- - example_id : Optional[Union[UUID, str]], default=None - The ID of the example being traced. It can be either a UUID or a string. - """ - super().__init__(**kwargs) - self.example_id = ( - UUID(example_id) if isinstance(example_id, str) else example_id - ) - self.traced_runs: List[Run] = [] - - def _persist_run(self, run: Run) -> None: - """ - Persist a run by adding it to the traced_runs list. - - Parameters - ---------- - run : Run - The run to be persisted. - """ - run_ = run.copy() - run_.reference_example_id = self.example_id - self.traced_runs.append(run_) +__all__ = ["RunCollectorCallbackHandler"] diff --git a/libs/langchain/langchain/schema/callbacks/tracers/schemas.py b/libs/langchain/langchain/schema/callbacks/tracers/schemas.py index 4db455be2ea..87e0f35b4a9 100644 --- a/libs/langchain/langchain/schema/callbacks/tracers/schemas.py +++ b/libs/langchain/langchain/schema/callbacks/tracers/schemas.py @@ -1,140 +1,27 @@ -"""Schemas for tracers.""" -from __future__ import annotations - -import datetime -import warnings -from typing import Any, Dict, List, Optional, Type -from uuid import UUID - -from langsmith.schemas import RunBase as BaseRunV2 -from langsmith.schemas import RunTypeEnum as RunTypeEnumDep - -from langchain.pydantic_v1 import BaseModel, Field, root_validator -from langchain.schema import LLMResult - - -def RunTypeEnum() -> Type[RunTypeEnumDep]: - """RunTypeEnum.""" - warnings.warn( - "RunTypeEnum is deprecated. 
Please directly use a string instead" - " (e.g. 'llm', 'chain', 'tool').", - DeprecationWarning, - ) - return RunTypeEnumDep - - -class TracerSessionV1Base(BaseModel): - """Base class for TracerSessionV1.""" - - start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) - name: Optional[str] = None - extra: Optional[Dict[str, Any]] = None - - -class TracerSessionV1Create(TracerSessionV1Base): - """Create class for TracerSessionV1.""" - - -class TracerSessionV1(TracerSessionV1Base): - """TracerSessionV1 schema.""" - - id: int - - -class TracerSessionBase(TracerSessionV1Base): - """Base class for TracerSession.""" - - tenant_id: UUID - - -class TracerSession(TracerSessionBase): - """TracerSessionV1 schema for the V2 API.""" - - id: UUID - - -class BaseRun(BaseModel): - """Base class for Run.""" - - uuid: str - parent_uuid: Optional[str] = None - start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) - end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) - extra: Optional[Dict[str, Any]] = None - execution_order: int - child_execution_order: int - serialized: Dict[str, Any] - session_id: int - error: Optional[str] = None - - -class LLMRun(BaseRun): - """Class for LLMRun.""" - - prompts: List[str] - response: Optional[LLMResult] = None - - -class ChainRun(BaseRun): - """Class for ChainRun.""" - - inputs: Dict[str, Any] - outputs: Optional[Dict[str, Any]] = None - child_llm_runs: List[LLMRun] = Field(default_factory=list) - child_chain_runs: List[ChainRun] = Field(default_factory=list) - child_tool_runs: List[ToolRun] = Field(default_factory=list) - - -class ToolRun(BaseRun): - """Class for ToolRun.""" - - tool_input: str - output: Optional[str] = None - action: str - child_llm_runs: List[LLMRun] = Field(default_factory=list) - child_chain_runs: List[ChainRun] = Field(default_factory=list) - child_tool_runs: List[ToolRun] = Field(default_factory=list) - - -# Begin V2 API Schemas - - -class Run(BaseRunV2): - """Run schema for the V2 API in the Tracer.""" - - execution_order: int - child_execution_order: int - child_runs: List[Run] = Field(default_factory=list) - tags: Optional[List[str]] = Field(default_factory=list) - events: List[Dict[str, Any]] = Field(default_factory=list) - - @root_validator(pre=True) - def assign_name(cls, values: dict) -> dict: - """Assign name to the run.""" - if values.get("name") is None: - if "name" in values["serialized"]: - values["name"] = values["serialized"]["name"] - elif "id" in values["serialized"]: - values["name"] = values["serialized"]["id"][-1] - if values.get("events") is None: - values["events"] = [] - return values - - -ChainRun.update_forward_refs() -ToolRun.update_forward_refs() -Run.update_forward_refs() +from langchain_core.callbacks.tracers.schemas import ( + BaseRun, + ChainRun, + LLMRun, + Run, + RunTypeEnum, + ToolRun, + TracerSession, + TracerSessionBase, + TracerSessionV1, + TracerSessionV1Base, + TracerSessionV1Create, +) __all__ = [ - "BaseRun", - "ChainRun", - "LLMRun", - "Run", "RunTypeEnum", - "ToolRun", - "TracerSession", - "TracerSessionBase", - "TracerSessionV1", "TracerSessionV1Base", "TracerSessionV1Create", + "TracerSessionV1", + "TracerSessionBase", + "TracerSession", + "BaseRun", + "LLMRun", + "ChainRun", + "ToolRun", + "Run", ] diff --git a/libs/langchain/langchain/schema/callbacks/tracers/stdout.py b/libs/langchain/langchain/schema/callbacks/tracers/stdout.py index 564f419a5bf..1bb931a21e4 100644 --- 
a/libs/langchain/langchain/schema/callbacks/tracers/stdout.py +++ b/libs/langchain/langchain/schema/callbacks/tracers/stdout.py @@ -1,178 +1,13 @@ -import json -from typing import Any, Callable, List +from langchain_core.callbacks.tracers.stdout import ( + ConsoleCallbackHandler, + FunctionCallbackHandler, + elapsed, + try_json_stringify, +) -from langchain.schema.callbacks.tracers.base import BaseTracer -from langchain.schema.callbacks.tracers.schemas import Run -from langchain.utils.input import get_bolded_text, get_colored_text - - -def try_json_stringify(obj: Any, fallback: str) -> str: - """ - Try to stringify an object to JSON. - Args: - obj: Object to stringify. - fallback: Fallback string to return if the object cannot be stringified. - - Returns: - A JSON string if the object can be stringified, otherwise the fallback string. - - """ - try: - return json.dumps(obj, indent=2, ensure_ascii=False) - except Exception: - return fallback - - -def elapsed(run: Any) -> str: - """Get the elapsed time of a run. - - Args: - run: any object with a start_time and end_time attribute. - - Returns: - A string with the elapsed time in seconds or - milliseconds if time is less than a second. - - """ - elapsed_time = run.end_time - run.start_time - milliseconds = elapsed_time.total_seconds() * 1000 - if milliseconds < 1000: - return f"{milliseconds:.0f}ms" - return f"{(milliseconds / 1000):.2f}s" - - -class FunctionCallbackHandler(BaseTracer): - """Tracer that calls a function with a single str parameter.""" - - name: str = "function_callback_handler" - - def __init__(self, function: Callable[[str], None], **kwargs: Any) -> None: - super().__init__(**kwargs) - self.function_callback = function - - def _persist_run(self, run: Run) -> None: - pass - - def get_parents(self, run: Run) -> List[Run]: - parents = [] - current_run = run - while current_run.parent_run_id: - parent = self.run_map.get(str(current_run.parent_run_id)) - if parent: - parents.append(parent) - current_run = parent - else: - break - return parents - - def get_breadcrumbs(self, run: Run) -> str: - parents = self.get_parents(run)[::-1] - string = " > ".join( - f"{parent.execution_order}:{parent.run_type}:{parent.name}" - if i != len(parents) - 1 - else f"{parent.execution_order}:{parent.run_type}:{parent.name}" - for i, parent in enumerate(parents + [run]) - ) - return string - - # logging methods - def _on_chain_start(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - run_type = run.run_type.capitalize() - self.function_callback( - f"{get_colored_text('[chain/start]', color='green')} " - + get_bolded_text(f"[{crumbs}] Entering {run_type} run with input:\n") - + f"{try_json_stringify(run.inputs, '[inputs]')}" - ) - - def _on_chain_end(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - run_type = run.run_type.capitalize() - self.function_callback( - f"{get_colored_text('[chain/end]', color='blue')} " - + get_bolded_text( - f"[{crumbs}] [{elapsed(run)}] Exiting {run_type} run with output:\n" - ) - + f"{try_json_stringify(run.outputs, '[outputs]')}" - ) - - def _on_chain_error(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - run_type = run.run_type.capitalize() - self.function_callback( - f"{get_colored_text('[chain/error]', color='red')} " - + get_bolded_text( - f"[{crumbs}] [{elapsed(run)}] {run_type} run errored with error:\n" - ) - + f"{try_json_stringify(run.error, '[error]')}" - ) - - def _on_llm_start(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - inputs = ( - {"prompts": 
[p.strip() for p in run.inputs["prompts"]]} - if "prompts" in run.inputs - else run.inputs - ) - self.function_callback( - f"{get_colored_text('[llm/start]', color='green')} " - + get_bolded_text(f"[{crumbs}] Entering LLM run with input:\n") - + f"{try_json_stringify(inputs, '[inputs]')}" - ) - - def _on_llm_end(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - self.function_callback( - f"{get_colored_text('[llm/end]', color='blue')} " - + get_bolded_text( - f"[{crumbs}] [{elapsed(run)}] Exiting LLM run with output:\n" - ) - + f"{try_json_stringify(run.outputs, '[response]')}" - ) - - def _on_llm_error(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - self.function_callback( - f"{get_colored_text('[llm/error]', color='red')} " - + get_bolded_text( - f"[{crumbs}] [{elapsed(run)}] LLM run errored with error:\n" - ) - + f"{try_json_stringify(run.error, '[error]')}" - ) - - def _on_tool_start(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - self.function_callback( - f'{get_colored_text("[tool/start]", color="green")} ' - + get_bolded_text(f"[{crumbs}] Entering Tool run with input:\n") - + f'"{run.inputs["input"].strip()}"' - ) - - def _on_tool_end(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - if run.outputs: - self.function_callback( - f'{get_colored_text("[tool/end]", color="blue")} ' - + get_bolded_text( - f"[{crumbs}] [{elapsed(run)}] Exiting Tool run with output:\n" - ) - + f'"{run.outputs["output"].strip()}"' - ) - - def _on_tool_error(self, run: Run) -> None: - crumbs = self.get_breadcrumbs(run) - self.function_callback( - f"{get_colored_text('[tool/error]', color='red')} " - + get_bolded_text(f"[{crumbs}] [{elapsed(run)}] ") - + f"Tool run errored with error:\n" - f"{run.error}" - ) - - -class ConsoleCallbackHandler(FunctionCallbackHandler): - """Tracer that prints to the console.""" - - name: str = "console_callback_handler" - - def __init__(self, **kwargs: Any) -> None: - super().__init__(function=print, **kwargs) +__all__ = [ + "try_json_stringify", + "elapsed", + "FunctionCallbackHandler", + "ConsoleCallbackHandler", +] diff --git a/libs/langchain/langchain/schema/chat.py b/libs/langchain/langchain/schema/chat.py index f76194f4f85..0dfc5dd876c 100644 --- a/libs/langchain/langchain/schema/chat.py +++ b/libs/langchain/langchain/schema/chat.py @@ -1,13 +1,3 @@ -from typing import Sequence, TypedDict +from langchain_core.schema.chat import ChatSession -from langchain.schema import BaseMessage - - -class ChatSession(TypedDict, total=False): - """Chat Session represents a single - conversation, channel, or other group of messages.""" - - messages: Sequence[BaseMessage] - """The LangChain chat messages loaded from the source.""" - functions: Sequence[dict] - """The function calling specs for the messages.""" +__all__ = ["ChatSession"] diff --git a/libs/langchain/langchain/schema/chat_history.py b/libs/langchain/langchain/schema/chat_history.py index 1f74ed0cd25..0321f021672 100644 --- a/libs/langchain/langchain/schema/chat_history.py +++ b/libs/langchain/langchain/schema/chat_history.py @@ -1,67 +1,3 @@ -from __future__ import annotations +from langchain_core.schema.chat_history import BaseChatMessageHistory -from abc import ABC, abstractmethod -from typing import List - -from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage - - -class BaseChatMessageHistory(ABC): - """Abstract base class for storing chat message history. - - See `ChatMessageHistory` for default implementation. - - Example: - .. 
code-block:: python - - class FileChatMessageHistory(BaseChatMessageHistory): - storage_path: str - session_id: str - - @property - def messages(self): - with open(os.path.join(storage_path, session_id), 'r:utf-8') as f: - messages = json.loads(f.read()) - return messages_from_dict(messages) - - def add_message(self, message: BaseMessage) -> None: - messages = self.messages.append(_message_to_dict(message)) - with open(os.path.join(storage_path, session_id), 'w') as f: - json.dump(f, messages) - - def clear(self): - with open(os.path.join(storage_path, session_id), 'w') as f: - f.write("[]") - """ - - messages: List[BaseMessage] - """A list of Messages stored in-memory.""" - - def add_user_message(self, message: str) -> None: - """Convenience method for adding a human message string to the store. - - Args: - message: The string contents of a human message. - """ - self.add_message(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - """Convenience method for adding an AI message string to the store. - - Args: - message: The string contents of an AI message. - """ - self.add_message(AIMessage(content=message)) - - @abstractmethod - def add_message(self, message: BaseMessage) -> None: - """Add a Message object to the store. - - Args: - message: A BaseMessage object to store. - """ - raise NotImplementedError() - - @abstractmethod - def clear(self) -> None: - """Remove all messages from the store""" +__all__ = ["BaseChatMessageHistory"] diff --git a/libs/langchain/langchain/schema/document.py b/libs/langchain/langchain/schema/document.py index c552ebda7db..97ee9844f58 100644 --- a/libs/langchain/langchain/schema/document.py +++ b/libs/langchain/langchain/schema/document.py @@ -1,91 +1,3 @@ -from __future__ import annotations +from langchain_core.schema.document import BaseDocumentTransformer, Document -import asyncio -from abc import ABC, abstractmethod -from functools import partial -from typing import Any, Literal, Sequence - -from langchain.load.serializable import Serializable -from langchain.pydantic_v1 import Field - - -class Document(Serializable): - """Class for storing a piece of text and associated metadata.""" - - page_content: str - """String text.""" - metadata: dict = Field(default_factory=dict) - """Arbitrary metadata about the page content (e.g., source, relationships to other - documents, etc.). - """ - type: Literal["Document"] = "Document" - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this class is serializable.""" - return True - - -class BaseDocumentTransformer(ABC): - """Abstract base class for document transformation systems. - - A document transformation system takes a sequence of Documents and returns a - sequence of transformed Documents. - - Example: - .. 
code-block:: python - - class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): - embeddings: Embeddings - similarity_fn: Callable = cosine_similarity - similarity_threshold: float = 0.95 - - class Config: - arbitrary_types_allowed = True - - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: - stateful_documents = get_stateful_documents(documents) - embedded_documents = _get_embeddings_from_stateful_docs( - self.embeddings, stateful_documents - ) - included_idxs = _filter_similar_embeddings( - embedded_documents, self.similarity_fn, self.similarity_threshold - ) - return [stateful_documents[i] for i in sorted(included_idxs)] - - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: - raise NotImplementedError - - """ # noqa: E501 - - @abstractmethod - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: - """Transform a list of documents. - - Args: - documents: A sequence of Documents to be transformed. - - Returns: - A list of transformed Documents. - """ - - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: - """Asynchronously transform a list of documents. - - Args: - documents: A sequence of Documents to be transformed. - - Returns: - A list of transformed Documents. - """ - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.transform_documents, **kwargs), documents - ) +__all__ = ["Document", "BaseDocumentTransformer"] diff --git a/libs/langchain/langchain/schema/embeddings.py b/libs/langchain/langchain/schema/embeddings.py index c08a279750b..a5ada340d9d 100644 --- a/libs/langchain/langchain/schema/embeddings.py +++ b/libs/langchain/langchain/schema/embeddings.py @@ -1,27 +1,3 @@ -import asyncio -from abc import ABC, abstractmethod -from typing import List +from langchain_core.schema.embeddings import Embeddings - -class Embeddings(ABC): - """Interface for embedding models.""" - - @abstractmethod - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Embed search docs.""" - - @abstractmethod - def embed_query(self, text: str) -> List[float]: - """Embed query text.""" - - async def aembed_documents(self, texts: List[str]) -> List[List[float]]: - """Asynchronous Embed search docs.""" - return await asyncio.get_running_loop().run_in_executor( - None, self.embed_documents, texts - ) - - async def aembed_query(self, text: str) -> List[float]: - """Asynchronous Embed query text.""" - return await asyncio.get_running_loop().run_in_executor( - None, self.embed_query, text - ) +__all__ = ["Embeddings"] diff --git a/libs/langchain/langchain/schema/exceptions.py b/libs/langchain/langchain/schema/exceptions.py index 27ed0d07dc1..be1ad0b900c 100644 --- a/libs/langchain/langchain/schema/exceptions.py +++ b/libs/langchain/langchain/schema/exceptions.py @@ -1,2 +1,3 @@ -class LangChainException(Exception): - """General LangChain exception.""" +from langchain_core.schema.exceptions import LangChainException + +__all__ = ["LangChainException"] diff --git a/libs/langchain/langchain/schema/language_model.py b/libs/langchain/langchain/schema/language_model.py index c4e8e5169de..4f7ea44926a 100644 --- a/libs/langchain/langchain/schema/language_model.py +++ b/libs/langchain/langchain/schema/language_model.py @@ -1,291 +1,3 @@ -from __future__ import annotations +from langchain_core.schema.language_model import 
BaseLanguageModel, get_tokenizer -from abc import ABC, abstractmethod -from functools import lru_cache -from typing import ( - TYPE_CHECKING, - Any, - List, - Optional, - Sequence, - Set, - TypeVar, - Union, -) - -from typing_extensions import TypeAlias - -from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string -from langchain.schema.output import LLMResult -from langchain.schema.prompt import PromptValue -from langchain.schema.runnable import RunnableSerializable -from langchain.utils import get_pydantic_field_names - -if TYPE_CHECKING: - from langchain.callbacks.manager import Callbacks - - -@lru_cache(maxsize=None) # Cache the tokenizer -def get_tokenizer() -> Any: - try: - from transformers import GPT2TokenizerFast - except ImportError: - raise ImportError( - "Could not import transformers python package. " - "This is needed in order to calculate get_token_ids. " - "Please install it with `pip install transformers`." - ) - # create a GPT-2 tokenizer instance - return GPT2TokenizerFast.from_pretrained("gpt2") - - -def _get_token_ids_default_method(text: str) -> List[int]: - """Encode the text into token IDs.""" - # get the cached tokenizer - tokenizer = get_tokenizer() - - # tokenize the text using the GPT-2 tokenizer - return tokenizer.encode(text) - - -LanguageModelInput = Union[PromptValue, str, List[BaseMessage]] -LanguageModelOutput = TypeVar("LanguageModelOutput") - - -class BaseLanguageModel( - RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC -): - """Abstract base class for interfacing with language models. - - All language model wrappers inherit from BaseLanguageModel. - - Exposes three main methods: - - generate_prompt: generate language model outputs for a sequence of prompt - values. A prompt value is a model input that can be converted to any language - model input format (string or messages). - - predict: pass in a single string to a language model and return a string - prediction. - - predict_messages: pass in a sequence of BaseMessages (corresponding to a single - model call) to a language model and return a BaseMessage prediction. - - Each of these has an equivalent asynchronous method. - """ - - @property - def InputType(self) -> TypeAlias: - """Get the input type for this runnable.""" - from langchain.prompts.base import StringPromptValue - from langchain.prompts.chat import ChatPromptValueConcrete - - # This is a version of LanguageModelInput which replaces the abstract - # base class BaseMessage with a union of its subclasses, which makes - # for a much better schema. - return Union[ - str, - Union[StringPromptValue, ChatPromptValueConcrete], - List[AnyMessage], - ] - - @abstractmethod - def generate_prompt( - self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - """Pass a sequence of prompts to the model and return model generations. - - This method should make use of batched calls for models that expose a batched - API. - - Use this method when you want to: - 1. take advantage of batched calls, - 2. need more output from the model than just the top generated value, - 3. are building chains that are agnostic to the underlying language model - type (e.g., pure text completion models vs chat models). - - Args: - prompts: List of PromptValues. A PromptValue is an object that can be - converted to match the format of any language model (string for pure - text generation models and BaseMessages for chat models). 
- stop: Stop words to use when generating. Model output is cut off at the - first occurrence of any of these substrings. - callbacks: Callbacks to pass through. Used for executing additional - functionality, such as logging or streaming, throughout generation. - **kwargs: Arbitrary additional keyword arguments. These are usually passed - to the model provider API call. - - Returns: - An LLMResult, which contains a list of candidate Generations for each input - prompt and additional model provider-specific output. - """ - - @abstractmethod - async def agenerate_prompt( - self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - """Asynchronously pass a sequence of prompts and return model generations. - - This method should make use of batched calls for models that expose a batched - API. - - Use this method when you want to: - 1. take advantage of batched calls, - 2. need more output from the model than just the top generated value, - 3. are building chains that are agnostic to the underlying language model - type (e.g., pure text completion models vs chat models). - - Args: - prompts: List of PromptValues. A PromptValue is an object that can be - converted to match the format of any language model (string for pure - text generation models and BaseMessages for chat models). - stop: Stop words to use when generating. Model output is cut off at the - first occurrence of any of these substrings. - callbacks: Callbacks to pass through. Used for executing additional - functionality, such as logging or streaming, throughout generation. - **kwargs: Arbitrary additional keyword arguments. These are usually passed - to the model provider API call. - - Returns: - An LLMResult, which contains a list of candidate Generations for each input - prompt and additional model provider-specific output. - """ - - @abstractmethod - def predict( - self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any - ) -> str: - """Pass a single string input to the model and return a string prediction. - - Use this method when passing in raw text. If you want to pass in specific - types of chat messages, use predict_messages. - - Args: - text: String input to pass to the model. - stop: Stop words to use when generating. Model output is cut off at the - first occurrence of any of these substrings. - **kwargs: Arbitrary additional keyword arguments. These are usually passed - to the model provider API call. - - Returns: - Top model prediction as a string. - """ - - @abstractmethod - def predict_messages( - self, - messages: List[BaseMessage], - *, - stop: Optional[Sequence[str]] = None, - **kwargs: Any, - ) -> BaseMessage: - """Pass a message sequence to the model and return a message prediction. - - Use this method when passing in chat messages. If you want to pass in raw text, - use predict. - - Args: - messages: A sequence of chat messages corresponding to a single model input. - stop: Stop words to use when generating. Model output is cut off at the - first occurrence of any of these substrings. - **kwargs: Arbitrary additional keyword arguments. These are usually passed - to the model provider API call. - - Returns: - Top model prediction as a message. - """ - - @abstractmethod - async def apredict( - self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any - ) -> str: - """Asynchronously pass a string to the model and return a string prediction. 
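# --- Editorial aside (illustrative sketch, not part of this patch): the predict /
# get_num_tokens surface documented above, shown with FakeListLLM from
# langchain.llms.fake purely as a stand-in BaseLanguageModel so the snippet runs
# without a provider key; any concrete model would work the same way.
from langchain.llms.fake import FakeListLLM

model = FakeListLLM(responses=["Bonjour!"])
print(model.predict("Say hello in French"))  # returns the canned string response
# get_num_tokens falls back to the GPT-2 tokenizer and needs `transformers` installed:
# print(model.get_num_tokens("Say hello in French"))
# --- end aside ---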
- - Use this method when calling pure text generation models and only the top - candidate generation is needed. - - Args: - text: String input to pass to the model. - stop: Stop words to use when generating. Model output is cut off at the - first occurrence of any of these substrings. - **kwargs: Arbitrary additional keyword arguments. These are usually passed - to the model provider API call. - - Returns: - Top model prediction as a string. - """ - - @abstractmethod - async def apredict_messages( - self, - messages: List[BaseMessage], - *, - stop: Optional[Sequence[str]] = None, - **kwargs: Any, - ) -> BaseMessage: - """Asynchronously pass messages to the model and return a message prediction. - - Use this method when calling chat models and only the top - candidate generation is needed. - - Args: - messages: A sequence of chat messages corresponding to a single model input. - stop: Stop words to use when generating. Model output is cut off at the - first occurrence of any of these substrings. - **kwargs: Arbitrary additional keyword arguments. These are usually passed - to the model provider API call. - - Returns: - Top model prediction as a message. - """ - - def get_token_ids(self, text: str) -> List[int]: - """Return the ordered ids of the tokens in a text. - - Args: - text: The string input to tokenize. - - Returns: - A list of ids corresponding to the tokens in the text, in order they occur - in the text. - """ - return _get_token_ids_default_method(text) - - def get_num_tokens(self, text: str) -> int: - """Get the number of tokens present in the text. - - Useful for checking if an input will fit in a model's context window. - - Args: - text: The string input to tokenize. - - Returns: - The integer number of tokens in the text. - """ - return len(self.get_token_ids(text)) - - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - """Get the number of tokens in the messages. - - Useful for checking if an input will fit in a model's context window. - - Args: - messages: The message inputs to tokenize. - - Returns: - The sum of the number of tokens across the messages. - """ - return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages]) - - @classmethod - def _all_required_field_names(cls) -> Set: - """DEPRECATED: Kept for backwards compatibility. - - Use get_pydantic_field_names. - """ - return get_pydantic_field_names(cls) +__all__ = ["get_tokenizer", "BaseLanguageModel"] diff --git a/libs/langchain/langchain/schema/memory.py b/libs/langchain/langchain/schema/memory.py index cd4d572f985..325840af67f 100644 --- a/libs/langchain/langchain/schema/memory.py +++ b/libs/langchain/langchain/schema/memory.py @@ -1,59 +1,3 @@ -from __future__ import annotations +from langchain_core.schema.memory import BaseMemory -from abc import ABC, abstractmethod -from typing import Any, Dict, List - -from langchain.load.serializable import Serializable - - -class BaseMemory(Serializable, ABC): - """Abstract base class for memory in Chains. - - Memory refers to state in Chains. Memory can be used to store information about - past executions of a Chain and inject that information into the inputs of - future executions of the Chain. For example, for conversational Chains Memory - can be used to store conversations and automatically add them to future model - prompts so that the model has the necessary context to respond coherently to - the latest input. - - Example: - .. 
code-block:: python - - class SimpleMemory(BaseMemory): - memories: Dict[str, Any] = dict() - - @property - def memory_variables(self) -> List[str]: - return list(self.memories.keys()) - - def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: - return self.memories - - def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: - pass - - def clear(self) -> None: - pass - """ # noqa: E501 - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @property - @abstractmethod - def memory_variables(self) -> List[str]: - """The string keys this memory class will add to chain inputs.""" - - @abstractmethod - def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - """Return key-value pairs given the text input to the chain.""" - - @abstractmethod - def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: - """Save the context of this chain run to memory.""" - - @abstractmethod - def clear(self) -> None: - """Clear memory contents.""" +__all__ = ["BaseMemory"] diff --git a/libs/langchain/langchain/schema/messages.py b/libs/langchain/langchain/schema/messages.py index af6ef181902..8d23953bb3a 100644 --- a/libs/langchain/langchain/schema/messages.py +++ b/libs/langchain/langchain/schema/messages.py @@ -1,415 +1,41 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union - -from typing_extensions import Literal - -from langchain.load.serializable import Serializable -from langchain.pydantic_v1 import Extra, Field - -if TYPE_CHECKING: - from langchain.prompts.chat import ChatPromptTemplate - - -def get_buffer_string( - messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI" -) -> str: - """Convert sequence of Messages to strings and concatenate them into one string. - - Args: - messages: Messages to be converted to strings. - human_prefix: The prefix to prepend to contents of HumanMessages. - ai_prefix: THe prefix to prepend to contents of AIMessages. - - Returns: - A single string concatenation of all input messages. - - Example: - .. code-block:: python - - from langchain.schema import AIMessage, HumanMessage - - messages = [ - HumanMessage(content="Hi, how are you?"), - AIMessage(content="Good, how are you?"), - ] - get_buffer_string(messages) - # -> "Human: Hi, how are you?\nAI: Good, how are you?" - """ - string_messages = [] - for m in messages: - if isinstance(m, HumanMessage): - role = human_prefix - elif isinstance(m, AIMessage): - role = ai_prefix - elif isinstance(m, SystemMessage): - role = "System" - elif isinstance(m, FunctionMessage): - role = "Function" - elif isinstance(m, ChatMessage): - role = m.role - else: - raise ValueError(f"Got unsupported message type: {m}") - message = f"{role}: {m.content}" - if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs: - message += f"{m.additional_kwargs['function_call']}" - string_messages.append(message) - - return "\n".join(string_messages) - - -class BaseMessage(Serializable): - """The base abstract Message class. - - Messages are the inputs and outputs of ChatModels. 
- """ - - content: Union[str, List[Union[str, Dict]]] - """The string contents of the message.""" - - additional_kwargs: dict = Field(default_factory=dict) - """Any additional information.""" - - type: str - - class Config: - extra = Extra.allow - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this class is serializable.""" - return True - - def __add__(self, other: Any) -> ChatPromptTemplate: - from langchain.prompts.chat import ChatPromptTemplate - - prompt = ChatPromptTemplate(messages=[self]) - return prompt + other - - -def merge_content( - first_content: Union[str, List[Union[str, Dict]]], - second_content: Union[str, List[Union[str, Dict]]], -) -> Union[str, List[Union[str, Dict]]]: - # If first chunk is a string - if isinstance(first_content, str): - # If the second chunk is also a string, then merge them naively - if isinstance(second_content, str): - return first_content + second_content - # If the second chunk is a list, add the first chunk to the start of the list - else: - return_list: List[Union[str, Dict]] = [first_content] - return return_list + second_content - # If both are lists, merge them naively - elif isinstance(second_content, List): - return first_content + second_content - # If the first content is a list, and the second content is a string - else: - # If the last element of the first content is a string - # Add the second content to the last element - if isinstance(first_content[-1], str): - return first_content[:-1] + [first_content[-1] + second_content] - else: - # Otherwise, add the second content as a new element of the list - return first_content + [second_content] - - -class BaseMessageChunk(BaseMessage): - """A Message chunk, which can be concatenated with other Message chunks.""" - - def _merge_kwargs_dict( - self, left: Dict[str, Any], right: Dict[str, Any] - ) -> Dict[str, Any]: - """Merge additional_kwargs from another BaseMessageChunk into this one.""" - merged = left.copy() - for k, v in right.items(): - if k not in merged: - merged[k] = v - elif type(merged[k]) != type(v): - raise ValueError( - f'additional_kwargs["{k}"] already exists in this message,' - " but with a different type." - ) - elif isinstance(merged[k], str): - merged[k] += v - elif isinstance(merged[k], dict): - merged[k] = self._merge_kwargs_dict(merged[k], v) - else: - raise ValueError( - f"Additional kwargs key {k} already exists in this message." - ) - return merged - - def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore - if isinstance(other, BaseMessageChunk): - # If both are (subclasses of) BaseMessageChunk, - # concat into a single BaseMessageChunk - - if isinstance(self, ChatMessageChunk): - return self.__class__( - role=self.role, - content=merge_content(self.content, other.content), - additional_kwargs=self._merge_kwargs_dict( - self.additional_kwargs, other.additional_kwargs - ), - ) - return self.__class__( - content=merge_content(self.content, other.content), - additional_kwargs=self._merge_kwargs_dict( - self.additional_kwargs, other.additional_kwargs - ), - ) - else: - raise TypeError( - 'unsupported operand type(s) for +: "' - f"{self.__class__.__name__}" - f'" and "{other.__class__.__name__}"' - ) - - -class HumanMessage(BaseMessage): - """A Message from a human.""" - - example: bool = False - """Whether this Message is being passed in to the model as part of an example - conversation. 
- """ - - type: Literal["human"] = "human" - - -HumanMessage.update_forward_refs() - - -class HumanMessageChunk(HumanMessage, BaseMessageChunk): - """A Human Message chunk.""" - - # Ignoring mypy re-assignment here since we're overriding the value - # to make sure that the chunk variant can be discriminated from the - # non-chunk variant. - type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] # noqa: E501 - - -class AIMessage(BaseMessage): - """A Message from an AI.""" - - example: bool = False - """Whether this Message is being passed in to the model as part of an example - conversation. - """ - - type: Literal["ai"] = "ai" - - -AIMessage.update_forward_refs() - - -class AIMessageChunk(AIMessage, BaseMessageChunk): - """A Message chunk from an AI.""" - - # Ignoring mypy re-assignment here since we're overriding the value - # to make sure that the chunk variant can be discriminated from the - # non-chunk variant. - type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501 - - def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore - if isinstance(other, AIMessageChunk): - if self.example != other.example: - raise ValueError( - "Cannot concatenate AIMessageChunks with different example values." - ) - - return self.__class__( - example=self.example, - content=merge_content(self.content, other.content), - additional_kwargs=self._merge_kwargs_dict( - self.additional_kwargs, other.additional_kwargs - ), - ) - - return super().__add__(other) - - -class SystemMessage(BaseMessage): - """A Message for priming AI behavior, usually passed in as the first of a sequence - of input messages. - """ - - type: Literal["system"] = "system" - - -SystemMessage.update_forward_refs() - - -class SystemMessageChunk(SystemMessage, BaseMessageChunk): - """A System Message chunk.""" - - # Ignoring mypy re-assignment here since we're overriding the value - # to make sure that the chunk variant can be discriminated from the - # non-chunk variant. - type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] # noqa: E501 - - -class FunctionMessage(BaseMessage): - """A Message for passing the result of executing a function back to a model.""" - - name: str - """The name of the function that was executed.""" - - type: Literal["function"] = "function" - - -FunctionMessage.update_forward_refs() - - -class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): - """A Function Message chunk.""" - - # Ignoring mypy re-assignment here since we're overriding the value - # to make sure that the chunk variant can be discriminated from the - # non-chunk variant. - type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment] - - def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore - if isinstance(other, FunctionMessageChunk): - if self.name != other.name: - raise ValueError( - "Cannot concatenate FunctionMessageChunks with different names." 
- ) - - return self.__class__( - name=self.name, - content=merge_content(self.content, other.content), - additional_kwargs=self._merge_kwargs_dict( - self.additional_kwargs, other.additional_kwargs - ), - ) - - return super().__add__(other) - - -class ToolMessage(BaseMessage): - """A Message for passing the result of executing a tool back to a model.""" - - tool_call_id: str - """Tool call that this message is responding to.""" - - type: Literal["tool"] = "tool" - - -ToolMessage.update_forward_refs() - - -class ToolMessageChunk(ToolMessage, BaseMessageChunk): - """A Tool Message chunk.""" - - # Ignoring mypy re-assignment here since we're overriding the value - # to make sure that the chunk variant can be discriminated from the - # non-chunk variant. - type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment] - - def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore - if isinstance(other, ToolMessageChunk): - if self.tool_call_id != other.tool_call_id: - raise ValueError( - "Cannot concatenate ToolMessageChunks with different names." - ) - - return self.__class__( - tool_call_id=self.tool_call_id, - content=merge_content(self.content, other.content), - additional_kwargs=self._merge_kwargs_dict( - self.additional_kwargs, other.additional_kwargs - ), - ) - - return super().__add__(other) - - -class ChatMessage(BaseMessage): - """A Message that can be assigned an arbitrary speaker (i.e. role).""" - - role: str - """The speaker / role of the Message.""" - - type: Literal["chat"] = "chat" - - -ChatMessage.update_forward_refs() - - -class ChatMessageChunk(ChatMessage, BaseMessageChunk): - """A Chat Message chunk.""" - - # Ignoring mypy re-assignment here since we're overriding the value - # to make sure that the chunk variant can be discriminated from the - # non-chunk variant. - type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore - - def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore - if isinstance(other, ChatMessageChunk): - if self.role != other.role: - raise ValueError( - "Cannot concatenate ChatMessageChunks with different roles." - ) - - return self.__class__( - role=self.role, - content=merge_content(self.content, other.content), - additional_kwargs=self._merge_kwargs_dict( - self.additional_kwargs, other.additional_kwargs - ), - ) - - return super().__add__(other) - - -AnyMessage = Union[ - AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage +from langchain_core.schema.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + ChatMessageChunk, + FunctionMessage, + FunctionMessageChunk, + HumanMessage, + HumanMessageChunk, + SystemMessage, + SystemMessageChunk, + ToolMessage, + ToolMessageChunk, + get_buffer_string, + merge_content, + messages_from_dict, + messages_to_dict, +) + +__all__ = [ + "get_buffer_string", + "BaseMessage", + "merge_content", + "BaseMessageChunk", + "HumanMessage", + "HumanMessageChunk", + "AIMessage", + "AIMessageChunk", + "SystemMessage", + "SystemMessageChunk", + "FunctionMessage", + "FunctionMessageChunk", + "ToolMessage", + "ToolMessageChunk", + "ChatMessage", + "ChatMessageChunk", + "messages_to_dict", + "messages_from_dict", ] - - -def _message_to_dict(message: BaseMessage) -> dict: - return {"type": message.type, "data": message.dict()} - - -def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]: - """Convert a sequence of Messages to a list of dictionaries. 
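# --- Editorial aside (illustrative sketch, not part of this patch): messages_to_dict /
# messages_from_dict, re-exported in the shim above, round-trip messages through
# plain dicts keyed by message "type". A minimal illustration:
from langchain.schema.messages import AIMessage, HumanMessage, messages_from_dict, messages_to_dict

original = [HumanMessage(content="hi"), AIMessage(content="hello")]
restored = messages_from_dict(messages_to_dict(original))
assert [m.content for m in restored] == ["hi", "hello"]
assert isinstance(restored[1], AIMessage)  # the "type" key selects the message class
# --- end aside ---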
- - Args: - messages: Sequence of messages (as BaseMessages) to convert. - - Returns: - List of messages as dicts. - """ - return [_message_to_dict(m) for m in messages] - - -def _message_from_dict(message: dict) -> BaseMessage: - _type = message["type"] - if _type == "human": - return HumanMessage(**message["data"]) - elif _type == "ai": - return AIMessage(**message["data"]) - elif _type == "system": - return SystemMessage(**message["data"]) - elif _type == "chat": - return ChatMessage(**message["data"]) - elif _type == "function": - return FunctionMessage(**message["data"]) - elif _type == "tool": - return ToolMessage(**message["data"]) - else: - raise ValueError(f"Got unexpected message type: {_type}") - - -def messages_from_dict(messages: List[dict]) -> List[BaseMessage]: - """Convert a sequence of messages from dicts to Message objects. - - Args: - messages: Sequence of messages (as dicts) to convert. - - Returns: - List of messages (BaseMessages). - """ - return [_message_from_dict(m) for m in messages] diff --git a/libs/langchain/langchain/schema/output.py b/libs/langchain/langchain/schema/output.py index b6bb22aa02a..7ed8ecc0dcf 100644 --- a/libs/langchain/langchain/schema/output.py +++ b/libs/langchain/langchain/schema/output.py @@ -1,175 +1,19 @@ -from __future__ import annotations +from langchain_core.schema.output import ( + ChatGeneration, + ChatGenerationChunk, + ChatResult, + Generation, + GenerationChunk, + LLMResult, + RunInfo, +) -from copy import deepcopy -from typing import Any, Dict, List, Literal, Optional -from uuid import UUID - -from langchain.load.serializable import Serializable -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema.messages import BaseMessage, BaseMessageChunk - - -class Generation(Serializable): - """A single text generation output.""" - - text: str - """Generated text output.""" - - generation_info: Optional[Dict[str, Any]] = None - """Raw response from the provider. May include things like the - reason for finishing or token log probabilities. 
- """ - type: Literal["Generation"] = "Generation" - """Type is used exclusively for serialization purposes.""" - # TODO: add log probs as separate attribute - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this class is serializable.""" - return True - - -class GenerationChunk(Generation): - """A Generation chunk, which can be concatenated with other Generation chunks.""" - - def __add__(self, other: GenerationChunk) -> GenerationChunk: - if isinstance(other, GenerationChunk): - generation_info = ( - {**(self.generation_info or {}), **(other.generation_info or {})} - if self.generation_info is not None or other.generation_info is not None - else None - ) - return GenerationChunk( - text=self.text + other.text, - generation_info=generation_info, - ) - else: - raise TypeError( - f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" - ) - - -class ChatGeneration(Generation): - """A single chat generation output.""" - - text: str = "" - """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message.""" - message: BaseMessage - """The message output by the chat model.""" - # Override type to be ChatGeneration, ignore mypy error as this is intentional - type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment] - """Type is used exclusively for serialization purposes.""" - - @root_validator - def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Set the text attribute to be the contents of the message.""" - try: - values["text"] = values["message"].content - except (KeyError, AttributeError) as e: - raise ValueError("Error while initializing ChatGeneration") from e - return values - - -class ChatGenerationChunk(ChatGeneration): - """A ChatGeneration chunk, which can be concatenated with other - ChatGeneration chunks. - - Attributes: - message: The message chunk output by the chat model. - """ - - message: BaseMessageChunk - # Override type to be ChatGeneration, ignore mypy error as this is intentional - type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501 - """Type is used exclusively for serialization purposes.""" - - def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk: - if isinstance(other, ChatGenerationChunk): - generation_info = ( - {**(self.generation_info or {}), **(other.generation_info or {})} - if self.generation_info is not None or other.generation_info is not None - else None - ) - return ChatGenerationChunk( - message=self.message + other.message, - generation_info=generation_info, - ) - else: - raise TypeError( - f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" - ) - - -class RunInfo(BaseModel): - """Class that contains metadata for a single execution of a Chain or model.""" - - run_id: UUID - """A unique identifier for the model or chain run.""" - - -class ChatResult(BaseModel): - """Class that contains all results for a single chat model call.""" - - generations: List[ChatGeneration] - """List of the chat generations. This is a List because an input can have multiple - candidate generations. - """ - llm_output: Optional[dict] = None - """For arbitrary LLM provider specific output.""" - - -class LLMResult(BaseModel): - """Class that contains all results for a batched LLM call.""" - - generations: List[List[Generation]] - """List of generated outputs. 
This is a List[List[]] because - each input could have multiple candidate generations.""" - llm_output: Optional[dict] = None - """Arbitrary LLM provider-specific output.""" - run: Optional[List[RunInfo]] = None - """List of metadata info for model call for each input.""" - - def flatten(self) -> List[LLMResult]: - """Flatten generations into a single list. - - Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult - contains only a single Generation. If token usage information is available, - it is kept only for the LLMResult corresponding to the top-choice - Generation, to avoid over-counting of token usage downstream. - - Returns: - List of LLMResults where each returned LLMResult contains a single - Generation. - """ - llm_results = [] - for i, gen_list in enumerate(self.generations): - # Avoid double counting tokens in OpenAICallback - if i == 0: - llm_results.append( - LLMResult( - generations=[gen_list], - llm_output=self.llm_output, - ) - ) - else: - if self.llm_output is not None: - llm_output = deepcopy(self.llm_output) - llm_output["token_usage"] = dict() - else: - llm_output = None - llm_results.append( - LLMResult( - generations=[gen_list], - llm_output=llm_output, - ) - ) - return llm_results - - def __eq__(self, other: object) -> bool: - """Check for LLMResult equality by ignoring any metadata related to runs.""" - if not isinstance(other, LLMResult): - return NotImplemented - return ( - self.generations == other.generations - and self.llm_output == other.llm_output - ) +__all__ = [ + "Generation", + "GenerationChunk", + "ChatGeneration", + "ChatGenerationChunk", + "RunInfo", + "ChatResult", + "LLMResult", +] diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index bdc47fea4cf..3702e9ea168 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -1,475 +1,19 @@ -from __future__ import annotations - -import asyncio -import functools -from abc import ABC, abstractmethod -from typing import ( - Any, - AsyncIterator, - Dict, - Generic, - Iterator, - List, - Optional, - Type, - TypeVar, - Union, +from langchain_core.schema.output_parser import ( + BaseCumulativeTransformOutputParser, + BaseGenerationOutputParser, + BaseLLMOutputParser, + BaseOutputParser, + BaseTransformOutputParser, + OutputParserException, + StrOutputParser, ) -from typing_extensions import get_args - -from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk -from langchain.schema.output import ( - ChatGeneration, - ChatGenerationChunk, - Generation, - GenerationChunk, -) -from langchain.schema.prompt import PromptValue -from langchain.schema.runnable import RunnableConfig, RunnableSerializable - -T = TypeVar("T") - - -class BaseLLMOutputParser(Generic[T], ABC): - """Abstract base class for parsing the outputs of a model.""" - - @abstractmethod - def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: - """Parse a list of candidate model Generations into a specific format. - - Args: - result: A list of Generations to be parsed. The Generations are assumed - to be different candidate outputs for a single model input. - - Returns: - Structured output. - """ - - async def aparse_result( - self, result: List[Generation], *, partial: bool = False - ) -> T: - """Parse a list of candidate model Generations into a specific format. - - Args: - result: A list of Generations to be parsed. 
The Generations are assumed - to be different candidate outputs for a single model input. - - Returns: - Structured output. - """ - return await asyncio.get_running_loop().run_in_executor( - None, self.parse_result, result - ) - - -class BaseGenerationOutputParser( - BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T] -): - """Base class to parse the output of an LLM call.""" - - @property - def InputType(self) -> Any: - return Union[str, AnyMessage] - - @property - def OutputType(self) -> Type[T]: - # even though mypy complains this isn't valid, - # it is good enough for pydantic to build the schema from - return T # type: ignore[misc] - - def invoke( - self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None - ) -> T: - if isinstance(input, BaseMessage): - return self._call_with_config( - lambda inner_input: self.parse_result( - [ChatGeneration(message=inner_input)] - ), - input, - config, - run_type="parser", - ) - else: - return self._call_with_config( - lambda inner_input: self.parse_result([Generation(text=inner_input)]), - input, - config, - run_type="parser", - ) - - async def ainvoke( - self, - input: str | BaseMessage, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> T: - if isinstance(input, BaseMessage): - return await self._acall_with_config( - lambda inner_input: self.aparse_result( - [ChatGeneration(message=inner_input)] - ), - input, - config, - run_type="parser", - ) - else: - return await self._acall_with_config( - lambda inner_input: self.aparse_result([Generation(text=inner_input)]), - input, - config, - run_type="parser", - ) - - -class BaseOutputParser( - BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T] -): - """Base class to parse the output of an LLM call. - - Output parsers help structure language model responses. - - Example: - .. code-block:: python - - class BooleanOutputParser(BaseOutputParser[bool]): - true_val: str = "YES" - false_val: str = "NO" - - def parse(self, text: str) -> bool: - cleaned_text = text.strip().upper() - if cleaned_text not in (self.true_val.upper(), self.false_val.upper()): - raise OutputParserException( - f"BooleanOutputParser expected output value to either be " - f"{self.true_val} or {self.false_val} (case-insensitive). " - f"Received {cleaned_text}." - ) - return cleaned_text == self.true_val.upper() - - @property - def _type(self) -> str: - return "boolean_output_parser" - """ # noqa: E501 - - @property - def InputType(self) -> Any: - return Union[str, AnyMessage] - - @property - def OutputType(self) -> Type[T]: - for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] - type_args = get_args(cls) - if type_args and len(type_args) == 1: - return type_args[0] - - raise TypeError( - f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. " - "Override the OutputType property to specify the output type." 
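The `OutputType` property above recovers the parser's generic argument by walking `__orig_bases__`. A standalone sketch of that mechanism (illustrative only, not LangChain code):

.. code-block:: python

    from typing import Generic, TypeVar, get_args

    T = TypeVar("T")

    class Parser(Generic[T]):
        @property
        def output_type(self) -> type:
            # __orig_bases__ records the parametrized bases of the subclass,
            # e.g. (Parser[bool],) for BoolParser below.
            for base in self.__class__.__orig_bases__:  # type: ignore[attr-defined]
                args = get_args(base)
                if len(args) == 1:
                    return args[0]
            raise TypeError("No inferable output type")

    class BoolParser(Parser[bool]):
        pass

    assert BoolParser().output_type is bool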
- ) - - def invoke( - self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None - ) -> T: - if isinstance(input, BaseMessage): - return self._call_with_config( - lambda inner_input: self.parse_result( - [ChatGeneration(message=inner_input)] - ), - input, - config, - run_type="parser", - ) - else: - return self._call_with_config( - lambda inner_input: self.parse_result([Generation(text=inner_input)]), - input, - config, - run_type="parser", - ) - - async def ainvoke( - self, - input: str | BaseMessage, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> T: - if isinstance(input, BaseMessage): - return await self._acall_with_config( - lambda inner_input: self.aparse_result( - [ChatGeneration(message=inner_input)] - ), - input, - config, - run_type="parser", - ) - else: - return await self._acall_with_config( - lambda inner_input: self.aparse_result([Generation(text=inner_input)]), - input, - config, - run_type="parser", - ) - - def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: - """Parse a list of candidate model Generations into a specific format. - - The return value is parsed from only the first Generation in the result, which - is assumed to be the highest-likelihood Generation. - - Args: - result: A list of Generations to be parsed. The Generations are assumed - to be different candidate outputs for a single model input. - - Returns: - Structured output. - """ - return self.parse(result[0].text) - - @abstractmethod - def parse(self, text: str) -> T: - """Parse a single string model output into some structure. - - Args: - text: String output of a language model. - - Returns: - Structured output. - """ - - async def aparse_result( - self, result: List[Generation], *, partial: bool = False - ) -> T: - """Parse a list of candidate model Generations into a specific format. - - The return value is parsed from only the first Generation in the result, which - is assumed to be the highest-likelihood Generation. - - Args: - result: A list of Generations to be parsed. The Generations are assumed - to be different candidate outputs for a single model input. - - Returns: - Structured output. - """ - return await asyncio.get_running_loop().run_in_executor( - None, functools.partial(self.parse_result, partial=partial), result - ) - - async def aparse(self, text: str) -> T: - """Parse a single string model output into some structure. - - Args: - text: String output of a language model. - - Returns: - Structured output. - """ - return await asyncio.get_running_loop().run_in_executor(None, self.parse, text) - - # TODO: rename 'completion' -> 'text'. - def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any: - """Parse the output of an LLM call with the input prompt for context. - - The prompt is largely provided in the event the OutputParser wants - to retry or fix the output in some way, and needs information from - the prompt to do so. - - Args: - completion: String output of a language model. - prompt: Input PromptValue. - - Returns: - Structured output - """ - return self.parse(completion) - - def get_format_instructions(self) -> str: - """Instructions on how the LLM output should be formatted.""" - raise NotImplementedError - - @property - def _type(self) -> str: - """Return the output parser type for serialization.""" - raise NotImplementedError( - f"_type property is not implemented in class {self.__class__.__name__}." - " This is required for serialization." 
- ) - - def dict(self, **kwargs: Any) -> Dict: - """Return dictionary representation of output parser.""" - output_parser_dict = super().dict(**kwargs) - try: - output_parser_dict["_type"] = self._type - except NotImplementedError: - pass - return output_parser_dict - - -class BaseTransformOutputParser(BaseOutputParser[T]): - """Base class for an output parser that can handle streaming input.""" - - def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]: - for chunk in input: - if isinstance(chunk, BaseMessage): - yield self.parse_result([ChatGeneration(message=chunk)]) - else: - yield self.parse_result([Generation(text=chunk)]) - - async def _atransform( - self, input: AsyncIterator[Union[str, BaseMessage]] - ) -> AsyncIterator[T]: - async for chunk in input: - if isinstance(chunk, BaseMessage): - yield self.parse_result([ChatGeneration(message=chunk)]) - else: - yield self.parse_result([Generation(text=chunk)]) - - def transform( - self, - input: Iterator[Union[str, BaseMessage]], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Iterator[T]: - yield from self._transform_stream_with_config( - input, self._transform, config, run_type="parser" - ) - - async def atransform( - self, - input: AsyncIterator[Union[str, BaseMessage]], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> AsyncIterator[T]: - async for chunk in self._atransform_stream_with_config( - input, self._atransform, config, run_type="parser" - ): - yield chunk - - -class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): - """Base class for an output parser that can handle streaming input.""" - - diff: bool = False - """In streaming mode, whether to yield diffs between the previous and current - parsed output, or just the current parsed output. - """ - - def _diff(self, prev: Optional[T], next: T) -> T: - """Convert parsed outputs into a diff format. 
The semantics of this are - up to the output parser.""" - raise NotImplementedError() - - def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: - prev_parsed = None - acc_gen = None - for chunk in input: - if isinstance(chunk, BaseMessageChunk): - chunk_gen: Generation = ChatGenerationChunk(message=chunk) - elif isinstance(chunk, BaseMessage): - chunk_gen = ChatGenerationChunk( - message=BaseMessageChunk(**chunk.dict()) - ) - else: - chunk_gen = GenerationChunk(text=chunk) - - if acc_gen is None: - acc_gen = chunk_gen - else: - acc_gen += chunk_gen - - parsed = self.parse_result([acc_gen], partial=True) - if parsed is not None and parsed != prev_parsed: - if self.diff: - yield self._diff(prev_parsed, parsed) - else: - yield parsed - prev_parsed = parsed - - async def _atransform( - self, input: AsyncIterator[Union[str, BaseMessage]] - ) -> AsyncIterator[T]: - prev_parsed = None - acc_gen = None - async for chunk in input: - if isinstance(chunk, BaseMessageChunk): - chunk_gen: Generation = ChatGenerationChunk(message=chunk) - elif isinstance(chunk, BaseMessage): - chunk_gen = ChatGenerationChunk( - message=BaseMessageChunk(**chunk.dict()) - ) - else: - chunk_gen = GenerationChunk(text=chunk) - - if acc_gen is None: - acc_gen = chunk_gen - else: - acc_gen += chunk_gen - - parsed = self.parse_result([acc_gen], partial=True) - if parsed is not None and parsed != prev_parsed: - if self.diff: - yield self._diff(prev_parsed, parsed) - else: - yield parsed - prev_parsed = parsed - - -class StrOutputParser(BaseTransformOutputParser[str]): - """OutputParser that parses LLMResult into the top likely string.""" - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this class is serializable.""" - return True - - @property - def _type(self) -> str: - """Return the output parser type for serialization.""" - return "default" - - def parse(self, text: str) -> str: - """Returns the input text with no changes.""" - return text - - -# TODO: Deprecate -NoOpOutputParser = StrOutputParser - - -class OutputParserException(ValueError): - """Exception that output parsers should raise to signify a parsing error. - - This exists to differentiate parsing errors from other code or execution errors - that also may arise inside the output parser. OutputParserExceptions will be - available to catch and handle in ways to fix the parsing error, while other - errors will be raised. - - Args: - error: The error that's being re-raised or an error message. - observation: String explanation of error which can be passed to a - model to try and remediate the issue. - llm_output: String model output which is error-ing. - send_to_llm: Whether to send the observation and llm_output back to an Agent - after an OutputParserException has been raised. This gives the underlying - model driving the agent the context that the previous output was improperly - structured, in the hopes that it will update the output to the correct - format. 
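Tying the pieces above together: a concrete parser implements `parse` and raises `OutputParserException` when the model output is malformed, so callers can handle parsing failures separately from other errors. A minimal sketch of such a parser (hypothetical `IntOutputParser`), using the `langchain_core.schema.output_parser` path this PR introduces:

.. code-block:: python

    from langchain_core.schema.output_parser import (
        BaseOutputParser,
        OutputParserException,
    )

    class IntOutputParser(BaseOutputParser[int]):
        """Parse the model output as a single integer."""

        def parse(self, text: str) -> int:
            try:
                return int(text.strip())
            except ValueError as e:
                raise OutputParserException(
                    f"Expected an integer, got: {text!r}"
                ) from e

        @property
        def _type(self) -> str:
            return "int_output_parser"

    assert IntOutputParser().parse(" 42 ") == 42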
- """ - - def __init__( - self, - error: Any, - observation: Optional[str] = None, - llm_output: Optional[str] = None, - send_to_llm: bool = False, - ): - super(OutputParserException, self).__init__(error) - if send_to_llm: - if observation is None or llm_output is None: - raise ValueError( - "Arguments 'observation' & 'llm_output'" - " are required if 'send_to_llm' is True" - ) - self.observation = observation - self.llm_output = llm_output - self.send_to_llm = send_to_llm +__all__ = [ + "BaseLLMOutputParser", + "BaseGenerationOutputParser", + "BaseOutputParser", + "BaseTransformOutputParser", + "BaseCumulativeTransformOutputParser", + "StrOutputParser", + "OutputParserException", +] diff --git a/libs/langchain/langchain/schema/prompt.py b/libs/langchain/langchain/schema/prompt.py index 8410eb73e6e..07a61bf1cb9 100644 --- a/libs/langchain/langchain/schema/prompt.py +++ b/libs/langchain/langchain/schema/prompt.py @@ -1,28 +1,3 @@ -from __future__ import annotations +from langchain_core.schema.prompt import PromptValue -from abc import ABC, abstractmethod -from typing import List - -from langchain.load.serializable import Serializable -from langchain.schema.messages import BaseMessage - - -class PromptValue(Serializable, ABC): - """Base abstract class for inputs to any language model. - - PromptValues can be converted to both LLM (pure text-generation) inputs and - ChatModel inputs. - """ - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this class is serializable.""" - return True - - @abstractmethod - def to_string(self) -> str: - """Return prompt value as string.""" - - @abstractmethod - def to_messages(self) -> List[BaseMessage]: - """Return prompt as a list of Messages.""" +__all__ = ["PromptValue"] diff --git a/libs/langchain/langchain/schema/prompt_template.py b/libs/langchain/langchain/schema/prompt_template.py index 31f0bb79857..a2dbb4fce55 100644 --- a/libs/langchain/langchain/schema/prompt_template.py +++ b/libs/langchain/langchain/schema/prompt_template.py @@ -1,228 +1,3 @@ -from __future__ import annotations +from langchain_core.schema.prompt_template import BasePromptTemplate, format_document -import json -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union - -import yaml - -from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator -from langchain.schema.document import Document -from langchain.schema.output_parser import BaseOutputParser -from langchain.schema.prompt import PromptValue -from langchain.schema.runnable import RunnableConfig, RunnableSerializable - - -class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): - """Base class for all prompt templates, returning a prompt.""" - - input_variables: List[str] - """A list of the names of the variables the prompt template expects.""" - input_types: Dict[str, Any] = Field(default_factory=dict) - """A dictionary of the types of the variables the prompt template expects. 
- If not provided, all variables are assumed to be strings.""" - output_parser: Optional[BaseOutputParser] = None - """How to parse the output of calling an LLM on this formatted prompt.""" - partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field( - default_factory=dict - ) - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this class is serializable.""" - return True - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @property - def OutputType(self) -> Any: - from langchain.prompts.base import StringPromptValue - from langchain.prompts.chat import ChatPromptValueConcrete - - return Union[StringPromptValue, ChatPromptValueConcrete] - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - # This is correct, but pydantic typings/mypy don't think so. - return create_model( # type: ignore[call-overload] - "PromptInput", - **{k: (self.input_types.get(k, str), None) for k in self.input_variables}, - ) - - def invoke( - self, input: Dict, config: Optional[RunnableConfig] = None - ) -> PromptValue: - return self._call_with_config( - lambda inner_input: self.format_prompt( - **{key: inner_input[key] for key in self.input_variables} - ), - input, - config, - run_type="prompt", - ) - - @abstractmethod - def format_prompt(self, **kwargs: Any) -> PromptValue: - """Create Chat Messages.""" - - @root_validator() - def validate_variable_names(cls, values: Dict) -> Dict: - """Validate variable names do not include restricted names.""" - if "stop" in values["input_variables"]: - raise ValueError( - "Cannot have an input variable named 'stop', as it is used internally," - " please rename." - ) - if "stop" in values["partial_variables"]: - raise ValueError( - "Cannot have an partial variable named 'stop', as it is used " - "internally, please rename." - ) - - overall = set(values["input_variables"]).intersection( - values["partial_variables"] - ) - if overall: - raise ValueError( - f"Found overlapping input and partial variables: {overall}" - ) - return values - - def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: - """Return a partial of the prompt template.""" - prompt_dict = self.__dict__.copy() - prompt_dict["input_variables"] = list( - set(self.input_variables).difference(kwargs) - ) - prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} - return type(self)(**prompt_dict) - - def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]: - # Get partial params: - partial_kwargs = { - k: v if isinstance(v, str) else v() - for k, v in self.partial_variables.items() - } - return {**partial_kwargs, **kwargs} - - @abstractmethod - def format(self, **kwargs: Any) -> str: - """Format the prompt with the inputs. - - Args: - kwargs: Any arguments to be passed to the prompt template. - - Returns: - A formatted string. - - Example: - - .. code-block:: python - - prompt.format(variable1="foo") - """ - - @property - def _prompt_type(self) -> str: - """Return the prompt type key.""" - raise NotImplementedError - - def dict(self, **kwargs: Any) -> Dict: - """Return dictionary representation of prompt.""" - prompt_dict = super().dict(**kwargs) - try: - prompt_dict["_type"] = self._prompt_type - except NotImplementedError: - pass - return prompt_dict - - def save(self, file_path: Union[Path, str]) -> None: - """Save the prompt. - - Args: - file_path: Path to directory to save prompt to. - - Example: - .. 
code-block:: python - - prompt.save(file_path="path/prompt.yaml") - """ - if self.partial_variables: - raise ValueError("Cannot save prompt with partial variables.") - - # Fetch dictionary to save - prompt_dict = self.dict() - if "_type" not in prompt_dict: - raise NotImplementedError(f"Prompt {self} does not support saving.") - - # Convert file to Path object. - if isinstance(file_path, str): - save_path = Path(file_path) - else: - save_path = file_path - - directory_path = save_path.parent - directory_path.mkdir(parents=True, exist_ok=True) - - if save_path.suffix == ".json": - with open(file_path, "w") as f: - json.dump(prompt_dict, f, indent=4) - elif save_path.suffix == ".yaml": - with open(file_path, "w") as f: - yaml.dump(prompt_dict, f, default_flow_style=False) - else: - raise ValueError(f"{save_path} must be json or yaml") - - -def format_document(doc: Document, prompt: BasePromptTemplate) -> str: - """Format a document into a string based on a prompt template. - - First, this pulls information from the document from two sources: - - 1. `page_content`: - This takes the information from the `document.page_content` - and assigns it to a variable named `page_content`. - 2. metadata: - This takes information from `document.metadata` and assigns - it to variables of the same name. - - Those variables are then passed into the `prompt` to produce a formatted string. - - Args: - doc: Document, the page_content and metadata will be used to create - the final string. - prompt: BasePromptTemplate, will be used to format the page_content - and metadata into the final string. - - Returns: - string of the document formatted. - - Example: - .. code-block:: python - - from langchain.schema import Document - from langchain.prompts import PromptTemplate - - doc = Document(page_content="This is a joke", metadata={"page": "1"}) - prompt = PromptTemplate.from_template("Page {page}: {page_content}") - format_document(doc, prompt) - >>> "Page 1: This is a joke" - """ - base_info = {"page_content": doc.page_content, **doc.metadata} - missing_metadata = set(prompt.input_variables).difference(base_info) - if len(missing_metadata) > 0: - required_metadata = [ - iv for iv in prompt.input_variables if iv != "page_content" - ] - raise ValueError( - f"Document prompt requires documents to have metadata variables: " - f"{required_metadata}. Received document with missing metadata: " - f"{list(missing_metadata)}." 
- ) - document_info = {k: base_info[k] for k in prompt.input_variables} - return prompt.format(**document_info) +__all__ = ["BasePromptTemplate", "format_document"] diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index 180093a5115..b58b13ef749 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -1,275 +1,3 @@ -from __future__ import annotations +from langchain_core.schema.retriever import BaseRetriever -import asyncio -import warnings -from abc import ABC, abstractmethod -from functools import partial -from inspect import signature -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -from langchain.load.dump import dumpd -from langchain.schema.document import Document -from langchain.schema.runnable import RunnableConfig, RunnableSerializable - -if TYPE_CHECKING: - from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, - Callbacks, - ) - - -class BaseRetriever(RunnableSerializable[str, List[Document]], ABC): - """Abstract base class for a Document retrieval system. - - A retrieval system is defined as something that can take string queries and return - the most 'relevant' Documents from some source. - - Example: - .. code-block:: python - - class TFIDFRetriever(BaseRetriever, BaseModel): - vectorizer: Any - docs: List[Document] - tfidf_array: Any - k: int = 4 - - class Config: - arbitrary_types_allowed = True - - def get_relevant_documents(self, query: str) -> List[Document]: - from sklearn.metrics.pairwise import cosine_similarity - - # Ip -- (n_docs,x), Op -- (n_docs,n_Feats) - query_vec = self.vectorizer.transform([query]) - # Op -- (n_docs,1) -- Cosine Sim with each doc - results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,)) - return [self.docs[i] for i in results.argsort()[-self.k :][::-1]] - """ # noqa: E501 - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - _new_arg_supported: bool = False - _expects_other_args: bool = False - tags: Optional[List[str]] = None - """Optional list of tags associated with the retriever. Defaults to None - These tags will be associated with each call to this retriever, - and passed as arguments to the handlers defined in `callbacks`. - You can use these to eg identify a specific instance of a retriever with its - use case. - """ - metadata: Optional[Dict[str, Any]] = None - """Optional metadata associated with the retriever. Defaults to None - This metadata will be associated with each call to this retriever, - and passed as arguments to the handlers defined in `callbacks`. - You can use these to eg identify a specific instance of a retriever with its - use case. - """ - - def __init_subclass__(cls, **kwargs: Any) -> None: - super().__init_subclass__(**kwargs) - # Version upgrade for old retrievers that implemented the public - # methods directly. 
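The retriever shim above keeps `BaseRetriever` importable from `langchain.schema.retriever`, and the `__init_subclass__` logic that follows steers subclasses toward implementing `_get_relevant_documents` rather than the public method. A minimal sketch of that contract (hypothetical `KeywordRetriever`; the callback-manager import path is assumed from the new `langchain_core` layout):

.. code-block:: python

    from typing import List

    from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
    from langchain_core.schema.document import Document
    from langchain_core.schema.retriever import BaseRetriever

    class KeywordRetriever(BaseRetriever):
        """Return up to k documents whose content contains the query."""

        docs: List[Document]
        k: int = 3

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            hits = [d for d in self.docs if query.lower() in d.page_content.lower()]
            return hits[: self.k]

    retriever = KeywordRetriever(docs=[Document(page_content="LangChain retrievers")])
    print(retriever.get_relevant_documents("retrievers"))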
- if cls.get_relevant_documents != BaseRetriever.get_relevant_documents: - warnings.warn( - "Retrievers must implement abstract `_get_relevant_documents` method" - " instead of `get_relevant_documents`", - DeprecationWarning, - ) - swap = cls.get_relevant_documents - cls.get_relevant_documents = ( # type: ignore[assignment] - BaseRetriever.get_relevant_documents - ) - cls._get_relevant_documents = swap # type: ignore[assignment] - if ( - hasattr(cls, "aget_relevant_documents") - and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents - ): - warnings.warn( - "Retrievers must implement abstract `_aget_relevant_documents` method" - " instead of `aget_relevant_documents`", - DeprecationWarning, - ) - aswap = cls.aget_relevant_documents - cls.aget_relevant_documents = ( # type: ignore[assignment] - BaseRetriever.aget_relevant_documents - ) - cls._aget_relevant_documents = aswap # type: ignore[assignment] - parameters = signature(cls._get_relevant_documents).parameters - cls._new_arg_supported = parameters.get("run_manager") is not None - # If a V1 retriever broke the interface and expects additional arguments - cls._expects_other_args = ( - len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0 - ) - - def invoke( - self, input: str, config: Optional[RunnableConfig] = None - ) -> List[Document]: - config = config or {} - return self.get_relevant_documents( - input, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - ) - - async def ainvoke( - self, - input: str, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> List[Document]: - config = config or {} - return await self.aget_relevant_documents( - input, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - ) - - @abstractmethod - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: - """Get documents relevant to a query. - Args: - query: String to find relevant documents for - run_manager: The callbacks handler to use - Returns: - List of relevant documents - """ - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - """Asynchronously get documents relevant to a query. - Args: - query: String to find relevant documents for - run_manager: The callbacks handler to use - Returns: - List of relevant documents - """ - return await asyncio.get_running_loop().run_in_executor( - None, partial(self._get_relevant_documents, run_manager=run_manager), query - ) - - def get_relevant_documents( - self, - query: str, - *, - callbacks: Callbacks = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - run_name: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """Retrieve documents relevant to a query. - Args: - query: string to find relevant documents for - callbacks: Callback manager or list of callbacks - tags: Optional list of tags associated with the retriever. Defaults to None - These tags will be associated with each call to this retriever, - and passed as arguments to the handlers defined in `callbacks`. - metadata: Optional metadata associated with the retriever. Defaults to None - This metadata will be associated with each call to this retriever, - and passed as arguments to the handlers defined in `callbacks`. 
- Returns: - List of relevant documents - """ - from langchain.callbacks.manager import CallbackManager - - callback_manager = CallbackManager.configure( - callbacks, - None, - verbose=kwargs.get("verbose", False), - inheritable_tags=tags, - local_tags=self.tags, - inheritable_metadata=metadata, - local_metadata=self.metadata, - ) - run_manager = callback_manager.on_retriever_start( - dumpd(self), - query, - name=run_name, - **kwargs, - ) - try: - _kwargs = kwargs if self._expects_other_args else {} - if self._new_arg_supported: - result = self._get_relevant_documents( - query, run_manager=run_manager, **_kwargs - ) - else: - result = self._get_relevant_documents(query, **_kwargs) - except Exception as e: - run_manager.on_retriever_error(e) - raise e - else: - run_manager.on_retriever_end( - result, - **kwargs, - ) - return result - - async def aget_relevant_documents( - self, - query: str, - *, - callbacks: Callbacks = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - run_name: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """Asynchronously get documents relevant to a query. - Args: - query: string to find relevant documents for - callbacks: Callback manager or list of callbacks - tags: Optional list of tags associated with the retriever. Defaults to None - These tags will be associated with each call to this retriever, - and passed as arguments to the handlers defined in `callbacks`. - metadata: Optional metadata associated with the retriever. Defaults to None - This metadata will be associated with each call to this retriever, - and passed as arguments to the handlers defined in `callbacks`. - Returns: - List of relevant documents - """ - from langchain.callbacks.manager import AsyncCallbackManager - - callback_manager = AsyncCallbackManager.configure( - callbacks, - None, - verbose=kwargs.get("verbose", False), - inheritable_tags=tags, - local_tags=self.tags, - inheritable_metadata=metadata, - local_metadata=self.metadata, - ) - run_manager = await callback_manager.on_retriever_start( - dumpd(self), - query, - name=run_name, - **kwargs, - ) - try: - _kwargs = kwargs if self._expects_other_args else {} - if self._new_arg_supported: - result = await self._aget_relevant_documents( - query, run_manager=run_manager, **_kwargs - ) - else: - result = await self._aget_relevant_documents(query, **_kwargs) - except Exception as e: - await run_manager.on_retriever_error(e) - raise e - else: - await run_manager.on_retriever_end( - result, - **kwargs, - ) - return result +__all__ = ["BaseRetriever"] diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index 3ae03b21e89..3a1f555d1c8 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -14,7 +14,7 @@ creating more responsive UX. This module contains schema and implementation of LangChain Runnables primitives. 
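The rewritten `__init__` below only swaps import sources, so the public `langchain.schema.runnable` surface is unchanged and the re-exported names are the `langchain_core` classes themselves. A quick sketch of what that means in practice:

.. code-block:: python

    from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
    from langchain_core.runnables.base import RunnableLambda as CoreRunnableLambda

    # Same class object, re-exported from its new home in langchain_core.
    assert RunnableLambda is CoreRunnableLambda

    chain = RunnablePassthrough() | RunnableLambda(lambda x: x * 2)
    assert chain.invoke(3) == 6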
""" -from langchain.schema.runnable.base import ( +from langchain_core.runnables.base import ( Runnable, RunnableBinding, RunnableGenerator, @@ -24,12 +24,12 @@ from langchain.schema.runnable.base import ( RunnableSequence, RunnableSerializable, ) -from langchain.schema.runnable.branch import RunnableBranch -from langchain.schema.runnable.config import RunnableConfig, patch_config -from langchain.schema.runnable.fallbacks import RunnableWithFallbacks -from langchain.schema.runnable.passthrough import RunnablePassthrough -from langchain.schema.runnable.router import RouterInput, RouterRunnable -from langchain.schema.runnable.utils import ( +from langchain_core.runnables.branch import RunnableBranch +from langchain_core.runnables.config import RunnableConfig, patch_config +from langchain_core.runnables.fallbacks import RunnableWithFallbacks +from langchain_core.runnables.passthrough import RunnablePassthrough +from langchain_core.runnables.router import RouterInput, RouterRunnable +from langchain_core.runnables.utils import ( ConfigurableField, ConfigurableFieldMultiOption, ConfigurableFieldSingleOption, diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index fea28543951..201c1a2c816 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1,3026 +1,27 @@ -from __future__ import annotations - -import asyncio -import inspect -import threading -from abc import ABC, abstractmethod -from concurrent.futures import FIRST_COMPLETED, wait -from functools import partial -from itertools import tee -from operator import itemgetter -from typing import ( - TYPE_CHECKING, - Any, - AsyncIterator, - Awaitable, - Callable, - Dict, - Generic, - Iterator, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, - overload, +from langchain_core.runnables.base import ( + Runnable, + RunnableBinding, + RunnableBindingBase, + RunnableEach, + RunnableEachBase, + RunnableGenerator, + RunnableLambda, + RunnableParallel, + RunnableSequence, + RunnableSerializable, + coerce_to_runnable, ) -from typing_extensions import Literal, get_args - -if TYPE_CHECKING: - from langchain.schema.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, - ) - from langchain.schema.callbacks.tracers.log_stream import RunLog, RunLogPatch - from langchain.schema.callbacks.tracers.root_listeners import Listener - from langchain.schema.runnable.fallbacks import ( - RunnableWithFallbacks as RunnableWithFallbacksT, - ) - -from langchain.load.dump import dumpd -from langchain.load.serializable import Serializable -from langchain.pydantic_v1 import BaseModel, Field, create_model -from langchain.schema.runnable.config import ( - RunnableConfig, - acall_func_with_variable_args, - call_func_with_variable_args, - ensure_config, - get_async_callback_manager_for_config, - get_callback_manager_for_config, - get_config_list, - get_executor_for_config, - merge_configs, - patch_config, -) -from langchain.schema.runnable.utils import ( - AddableDict, - AnyConfigurableField, - ConfigurableField, - ConfigurableFieldSpec, - Input, - Output, - accepts_config, - accepts_run_manager, - gather_with_concurrency, - get_function_first_arg_dict_keys, - get_lambda_source, - get_unique_config_specs, - indent_lines_after_first, -) -from langchain.utils.aiter import atee, py_anext -from langchain.utils.iter import safetee - -Other = TypeVar("Other") - - -class 
Runnable(Generic[Input, Output], ABC): - """A unit of work that can be invoked, batched, streamed, transformed and composed. - - Key Methods - =========== - - * invoke/ainvoke: Transforms a single input into an output. - * batch/abatch: Efficiently transforms multiple inputs into outputs. - * stream/astream: Streams output from a single input as it's produced. - * astream_log: Streams output and selected intermediate results from an input. - - Built-in optimizations: - - * Batch: By default, batch runs invoke() in parallel using a thread pool executor. - Override to optimize batching. - - * Async: Methods with "a" suffix are asynchronous. By default, they execute - the sync counterpart using asyncio's thread pool. - Override for native async. - - All methods accept an optional config argument, which can be used to configure - execution, add tags and metadata for tracing and debugging etc. - - Runnables expose schematic information about their input, output and config via - the input_schema property, the output_schema property and config_schema method. - - LCEL and Composition - ==================== - - The LangChain Expression Language (LCEL) is a declarative way to compose Runnables - into chains. Any chain constructed this way will automatically have sync, async, - batch, and streaming support. - - The main composition primitives are RunnableSequence and RunnableParallel. - - RunnableSequence invokes a series of runnables sequentially, with one runnable's - output serving as the next's input. Construct using the `|` operator or by - passing a list of runnables to RunnableSequence. - - RunnableParallel invokes runnables concurrently, providing the same input - to each. Construct it using a dict literal within a sequence or by passing a - dict to RunnableParallel. - - - For example, - - .. code-block:: python - - from langchain.schema.runnable import RunnableLambda - - # A RunnableSequence constructed using the `|` operator - sequence = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2) - sequence.invoke(1) # 4 - sequence.batch([1, 2, 3]) # [4, 6, 8] - - - # A sequence that contains a RunnableParallel constructed using a dict literal - sequence = RunnableLambda(lambda x: x + 1) | { - 'mul_2': RunnableLambda(lambda x: x * 2), - 'mul_5': RunnableLambda(lambda x: x * 5) - } - sequence.invoke(1) # {'mul_2': 4, 'mul_5': 10} - - Standard Methods - ================ - - All Runnables expose additional methods that can be used to modify their behavior - (e.g., add a retry policy, add lifecycle listeners, make them configurable, etc.). - - These methods will work on any Runnable, including Runnable chains constructed - by composing other Runnables. See the individual methods for details. - - For example, - - .. code-block:: python - - from langchain.schema.runnable import RunnableLambda - - import random - - def add_one(x: int) -> int: - return x + 1 - - - def buggy_double(y: int) -> int: - '''Buggy code that will fail 70% of the time''' - if random.random() > 0.3: - print('This code failed, and will probably be retried!') - raise ValueError('Triggered buggy code') - return y * 2 - - sequence = ( - RunnableLambda(add_one) | - RunnableLambda(buggy_double).with_retry( # Retry on failure - stop_after_attempt=10, - wait_exponential_jitter=False - ) - ) - - print(sequence.input_schema.schema()) # Show inferred input schema - print(sequence.output_schema.schema()) # Show inferred output schema - print(sequence.invoke(2)) # invoke the sequence (note the retry above!!) 
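In the same spirit as the retry example above, other standard methods layer behaviour onto an existing chain without modifying it; for instance, binding tracing configuration:

.. code-block:: python

    from langchain.schema.runnable import RunnableLambda

    chain = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2)

    # Tags and run_name only affect tracing/callbacks, not the computation.
    tagged = chain.with_config(tags=["demo"], run_name="add_then_double")
    assert tagged.invoke(3) == 8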
- - Debugging and tracing - ===================== - - As the chains get longer, it can be useful to be able to see intermediate results - to debug and trace the chain. - - You can set the global debug flag to True to enable debug output for all chains: - - .. code-block:: python - - from langchain.globals import set_debug - set_debug(True) - - Alternatively, you can pass existing or custom callbacks to any given chain: - - ... code-block:: python - - from langchain.callbacks.tracers import ConsoleCallbackHandler - - chain.invoke( - ..., - config={'callbacks': [ConsoleCallbackHandler()]} - ) - - For a UI (and much more) checkout LangSmith: https://docs.smith.langchain.com/ - """ - - @property - def InputType(self) -> Type[Input]: - """The type of input this runnable accepts specified as a type annotation.""" - for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] - type_args = get_args(cls) - if type_args and len(type_args) == 2: - return type_args[0] - - raise TypeError( - f"Runnable {self.__class__.__name__} doesn't have an inferable InputType. " - "Override the InputType property to specify the input type." - ) - - @property - def OutputType(self) -> Type[Output]: - """The type of output this runnable produces specified as a type annotation.""" - for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] - type_args = get_args(cls) - if type_args and len(type_args) == 2: - return type_args[1] - - raise TypeError( - f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. " - "Override the OutputType property to specify the output type." - ) - - @property - def input_schema(self) -> Type[BaseModel]: - """The type of input this runnable accepts specified as a pydantic model.""" - return self.get_input_schema() - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - """Get a pydantic model that can be used to validate input to the runnable. - - Runnables that leverage the configurable_fields and configurable_alternatives - methods will have a dynamic input schema that depends on which - configuration the runnable is invoked with. - - This method allows to get an input schema for a specific configuration. - - Args: - config: A config to use when generating the schema. - - Returns: - A pydantic model that can be used to validate input. - """ - root_type = self.InputType - - if inspect.isclass(root_type) and issubclass(root_type, BaseModel): - return root_type - - return create_model( - self.__class__.__name__ + "Input", __root__=(root_type, None) - ) - - @property - def output_schema(self) -> Type[BaseModel]: - """The type of output this runnable produces specified as a pydantic model.""" - return self.get_output_schema() - - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - """Get a pydantic model that can be used to validate output to the runnable. - - Runnables that leverage the configurable_fields and configurable_alternatives - methods will have a dynamic output schema that depends on which - configuration the runnable is invoked with. - - This method allows to get an output schema for a specific configuration. - - Args: - config: A config to use when generating the schema. - - Returns: - A pydantic model that can be used to validate output. 
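These schema accessors are useful when exposing chains over an API or validating inputs up front; a brief sketch of inspecting them on a trivial runnable:

.. code-block:: python

    from langchain.schema.runnable import RunnableLambda

    def add_one(x: int) -> int:
        return x + 1

    runnable = RunnableLambda(add_one)

    # Both return pydantic models; .schema() gives a JSON-schema-style dict.
    print(runnable.input_schema.schema())
    print(runnable.output_schema.schema())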
- """ - root_type = self.OutputType - - if inspect.isclass(root_type) and issubclass(root_type, BaseModel): - return root_type - - return create_model( - self.__class__.__name__ + "Output", __root__=(root_type, None) - ) - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - """List configurable fields for this runnable.""" - return [] - - def config_schema( - self, *, include: Optional[Sequence[str]] = None - ) -> Type[BaseModel]: - """The type of config this runnable accepts specified as a pydantic model. - - To mark a field as configurable, see the `configurable_fields` - and `configurable_alternatives` methods. - - Args: - include: A list of fields to include in the config schema. - - Returns: - A pydantic model that can be used to validate config. - """ - - class _Config: - arbitrary_types_allowed = True - - include = include or [] - config_specs = self.config_specs - configurable = ( - create_model( # type: ignore[call-overload] - "Configurable", - **{ - spec.id: ( - spec.annotation, - Field( - spec.default, title=spec.name, description=spec.description - ), - ) - for spec in config_specs - }, - ) - if config_specs - else None - ) - - return create_model( # type: ignore[call-overload] - self.__class__.__name__ + "Config", - __config__=_Config, - **({"configurable": (configurable, None)} if configurable else {}), - **{ - field_name: (field_type, None) - for field_name, field_type in RunnableConfig.__annotations__.items() - if field_name in [i for i in include if i != "configurable"] - }, - ) - - def __or__( - self, - other: Union[ - Runnable[Any, Other], - Callable[[Any], Other], - Callable[[Iterator[Any]], Iterator[Other]], - Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], - ], - ) -> RunnableSerializable[Input, Other]: - """Compose this runnable with another object to create a RunnableSequence.""" - return RunnableSequence(first=self, last=coerce_to_runnable(other)) - - def __ror__( - self, - other: Union[ - Runnable[Other, Any], - Callable[[Other], Any], - Callable[[Iterator[Other]], Iterator[Any]], - Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], - ], - ) -> RunnableSerializable[Other, Output]: - """Compose this runnable with another object to create a RunnableSequence.""" - return RunnableSequence(first=coerce_to_runnable(other), last=self) - - """ --- Public API --- """ - - @abstractmethod - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - """Transform a single input into an output. Override to implement. - - Args: - input: The input to the runnable. - config: A config to use when invoking the runnable. - The config supports standard keys like 'tags', 'metadata' for tracing - purposes, 'max_concurrency' for controlling how much work to do - in parallel, and other keys. Please refer to the RunnableConfig - for more details. - - Returns: - The output of the runnable. - """ - - async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - """Default implementation of ainvoke, calls invoke from a thread. - - The default implementation allows usage of async code even if - the runnable did not implement a native async version of invoke. - - Subclasses should override this method if they can run asynchronously. 
- """ - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, **kwargs), input, config - ) - - def batch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - """Default implementation runs invoke in parallel using a thread pool executor. - - The default implementation of batch works well for IO bound runnables. - - Subclasses should override this method if they can batch more efficiently; - e.g., if the underlying runnable uses an API which supports a batch mode. - """ - if not inputs: - return [] - - configs = get_config_list(config, len(inputs)) - - def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]: - if return_exceptions: - try: - return self.invoke(input, config, **kwargs) - except Exception as e: - return e - else: - return self.invoke(input, config, **kwargs) - - # If there's only one input, don't bother with the executor - if len(inputs) == 1: - return cast(List[Output], [invoke(inputs[0], configs[0])]) - - with get_executor_for_config(configs[0]) as executor: - return cast(List[Output], list(executor.map(invoke, inputs, configs))) - - async def abatch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - """Default implementation runs ainvoke in parallel using asyncio.gather. - - The default implementation of batch works well for IO bound runnables. - - Subclasses should override this method if they can batch more efficiently; - e.g., if the underlying runnable uses an API which supports a batch mode. - """ - if not inputs: - return [] - - configs = get_config_list(config, len(inputs)) - - async def ainvoke( - input: Input, config: RunnableConfig - ) -> Union[Output, Exception]: - if return_exceptions: - try: - return await self.ainvoke(input, config, **kwargs) - except Exception as e: - return e - else: - return await self.ainvoke(input, config, **kwargs) - - coros = map(ainvoke, inputs, configs) - return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) - - def stream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Iterator[Output]: - """ - Default implementation of stream, which calls invoke. - Subclasses should override this method if they support streaming output. - """ - yield self.invoke(input, config, **kwargs) - - async def astream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Output]: - """ - Default implementation of astream, which calls ainvoke. - Subclasses should override this method if they support streaming output. - """ - yield await self.ainvoke(input, config, **kwargs) - - @overload - def astream_log( - self, - input: Any, - config: Optional[RunnableConfig] = None, - *, - diff: Literal[True] = True, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[RunLogPatch]: - ... 
- - @overload - def astream_log( - self, - input: Any, - config: Optional[RunnableConfig] = None, - *, - diff: Literal[False], - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[RunLog]: - ... - - async def astream_log( - self, - input: Any, - config: Optional[RunnableConfig] = None, - *, - diff: bool = True, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, - **kwargs: Optional[Any], - ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]: - """ - Stream all output from a runnable, as reported to the callback system. - This includes all inner runs of LLMs, Retrievers, Tools, etc. - - Output is streamed as Log objects, which include a list of - jsonpatch ops that describe how the state of the run has changed in each - step, and the final state of the run. - - The jsonpatch ops can be applied in order to construct state. - """ - - from langchain.callbacks.base import BaseCallbackManager - from langchain.callbacks.tracers.log_stream import ( - LogStreamCallbackHandler, - RunLog, - RunLogPatch, - ) - - # Create a stream handler that will emit Log objects - stream = LogStreamCallbackHandler( - auto_close=False, - include_names=include_names, - include_types=include_types, - include_tags=include_tags, - exclude_names=exclude_names, - exclude_types=exclude_types, - exclude_tags=exclude_tags, - ) - - # Assign the stream handler to the config - config = config or {} - callbacks = config.get("callbacks") - if callbacks is None: - config["callbacks"] = [stream] - elif isinstance(callbacks, list): - config["callbacks"] = callbacks + [stream] - elif isinstance(callbacks, BaseCallbackManager): - callbacks = callbacks.copy() - callbacks.add_handler(stream, inherit=True) - config["callbacks"] = callbacks - else: - raise ValueError( - f"Unexpected type for callbacks: {callbacks}." - "Expected None, list or AsyncCallbackManager." - ) - - # Call the runnable in streaming mode, - # add each chunk to the output stream - async def consume_astream() -> None: - try: - async for chunk in self.astream(input, config, **kwargs): - await stream.send_stream.send( - RunLogPatch( - { - "op": "add", - "path": "/streamed_output/-", - "value": chunk, - } - ) - ) - finally: - await stream.send_stream.aclose() - - # Start the runnable in a task, so we can start consuming output - task = asyncio.create_task(consume_astream()) - - try: - # Yield each chunk from the output stream - if diff: - async for log in stream: - yield log - else: - state = RunLog(state=None) # type: ignore[arg-type] - async for log in stream: - state = state + log - yield state - finally: - # Wait for the runnable to finish, if not cancelled (eg. by break) - try: - await task - except asyncio.CancelledError: - pass - - def transform( - self, - input: Iterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Iterator[Output]: - """ - Default implementation of transform, which buffers input and then calls stream. 
- Subclasses should override this method if they can start producing output while - input is still being generated. - """ - final: Input - got_first_val = False - - for chunk in input: - if not got_first_val: - final = chunk - got_first_val = True - else: - # Make a best effort to gather, for any type that supports `+` - # This method should throw an error if gathering fails. - final = final + chunk # type: ignore[operator] - - if got_first_val: - yield from self.stream(final, config, **kwargs) - - async def atransform( - self, - input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Output]: - """ - Default implementation of atransform, which buffers input and calls astream. - Subclasses should override this method if they can start producing output while - input is still being generated. - """ - final: Input - got_first_val = False - - async for chunk in input: - if not got_first_val: - final = chunk - got_first_val = True - else: - # Make a best effort to gather, for any type that supports `+` - # This method should throw an error if gathering fails. - final = final + chunk # type: ignore[operator] - - if got_first_val: - async for output in self.astream(final, config, **kwargs): - yield output - - def bind(self, **kwargs: Any) -> Runnable[Input, Output]: - """ - Bind arguments to a Runnable, returning a new Runnable. - """ - return RunnableBinding(bound=self, kwargs=kwargs, config={}) - - def with_config( - self, - config: Optional[RunnableConfig] = None, - # Sadly Unpack is not well supported by mypy so this will have to be untyped - **kwargs: Any, - ) -> Runnable[Input, Output]: - """ - Bind config to a Runnable, returning a new Runnable. - """ - return RunnableBinding( - bound=self, - config=cast( - RunnableConfig, - {**(config or {}), **kwargs}, - ), # type: ignore[misc] - kwargs={}, - ) - - def with_listeners( - self, - *, - on_start: Optional[Listener] = None, - on_end: Optional[Listener] = None, - on_error: Optional[Listener] = None, - ) -> Runnable[Input, Output]: - """ - Bind lifecycle listeners to a Runnable, returning a new Runnable. - - on_start: Called before the runnable starts running, with the Run object. - on_end: Called after the runnable finishes running, with the Run object. - on_error: Called if the runnable throws an error, with the Run object. - - The Run object contains information about the run, including its id, - type, input, output, error, start_time, end_time, and any tags or metadata - added to the run. - """ - from langchain.callbacks.tracers.root_listeners import RootListenersTracer - - return RunnableBinding( - bound=self, - config_factories=[ - lambda config: { - "callbacks": [ - RootListenersTracer( - config=config, - on_start=on_start, - on_end=on_end, - on_error=on_error, - ) - ], - } - ], - ) - - def with_types( - self, - *, - input_type: Optional[Type[Input]] = None, - output_type: Optional[Type[Output]] = None, - ) -> Runnable[Input, Output]: - """ - Bind input and output types to a Runnable, returning a new Runnable. - """ - return RunnableBinding( - bound=self, - custom_input_type=input_type, - custom_output_type=output_type, - kwargs={}, - ) - - def with_retry( - self, - *, - retry_if_exception_type: Tuple[Type[BaseException], ...] = (Exception,), - wait_exponential_jitter: bool = True, - stop_after_attempt: int = 3, - ) -> Runnable[Input, Output]: - """Create a new Runnable that retries the original runnable on exceptions. 
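Alongside `bind`, `with_config` and `with_types` above, `with_listeners` attaches run-lifecycle hooks; a hedged sketch (the listener receives the tracer's `Run` object):

.. code-block:: python

    from langchain.schema.runnable import RunnableLambda

    def on_start(run) -> None:
        print(f"run {run.id} started")

    def on_end(run) -> None:
        print(f"run {run.id} finished")

    chain = RunnableLambda(lambda x: x + 1).with_listeners(
        on_start=on_start, on_end=on_end
    )
    chain.invoke(1)  # prints the start/end lines, returns 2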
- - Args: - retry_if_exception_type: A tuple of exception types to retry on - wait_exponential_jitter: Whether to add jitter to the wait time - between retries - stop_after_attempt: The maximum number of attempts to make before giving up - - Returns: - A new Runnable that retries the original runnable on exceptions. - """ - from langchain.schema.runnable.retry import RunnableRetry - - return RunnableRetry( - bound=self, - kwargs={}, - config={}, - retry_exception_types=retry_if_exception_type, - wait_exponential_jitter=wait_exponential_jitter, - max_attempt_number=stop_after_attempt, - ) - - def map(self) -> Runnable[List[Input], List[Output]]: - """ - Return a new Runnable that maps a list of inputs to a list of outputs, - by calling invoke() with each input. - """ - return RunnableEach(bound=self) - - def with_fallbacks( - self, - fallbacks: Sequence[Runnable[Input, Output]], - *, - exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,), - ) -> RunnableWithFallbacksT[Input, Output]: - """Add fallbacks to a runnable, returning a new Runnable. - - Args: - fallbacks: A sequence of runnables to try if the original runnable fails. - exceptions_to_handle: A tuple of exception types to handle. - - Returns: - A new Runnable that will try the original runnable, and then each - fallback in order, upon failures. - """ - from langchain.schema.runnable.fallbacks import RunnableWithFallbacks - - return RunnableWithFallbacks( - runnable=self, - fallbacks=fallbacks, - exceptions_to_handle=exceptions_to_handle, - ) - - """ --- Helper methods for Subclasses --- """ - - def _call_with_config( - self, - func: Union[ - Callable[[Input], Output], - Callable[[Input, CallbackManagerForChainRun], Output], - Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], - ], - input: Input, - config: Optional[RunnableConfig], - run_type: Optional[str] = None, - **kwargs: Optional[Any], - ) -> Output: - """Helper method to transform an Input value to an Output value, - with callbacks. Use this method to implement invoke() in subclasses.""" - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( - dumpd(self), - input, - run_type=run_type, - name=config.get("run_name"), - ) - try: - output = call_func_with_variable_args( - func, input, config, run_manager, **kwargs - ) - except BaseException as e: - run_manager.on_chain_error(e) - raise - else: - run_manager.on_chain_end(dumpd(output)) - return output - - async def _acall_with_config( - self, - func: Union[ - Callable[[Input], Awaitable[Output]], - Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], - Callable[ - [Input, AsyncCallbackManagerForChainRun, RunnableConfig], - Awaitable[Output], - ], - ], - input: Input, - config: Optional[RunnableConfig], - run_type: Optional[str] = None, - **kwargs: Optional[Any], - ) -> Output: - """Helper method to transform an Input value to an Output value, - with callbacks. 
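Both helpers follow the same pattern; a minimal sketch using the sync variant (the `CharCount` class is hypothetical, and `_call_with_config` is an internal helper, so this illustrates the intended pattern rather than a public contract):

.. code-block:: python

    from typing import Optional

    from langchain.schema.runnable import Runnable, RunnableConfig

    class CharCount(Runnable[str, int]):
        """Hypothetical runnable that counts characters."""

        def _count(self, text: str) -> int:
            return len(text)

        def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> int:
            # _call_with_config fires on_chain_start / on_chain_end
            # (and on_chain_error) around the wrapped function.
            return self._call_with_config(self._count, input, config)

    assert CharCount().invoke("hello") == 5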
Use this method to implement ainvoke() in subclasses.""" - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) - run_manager = await callback_manager.on_chain_start( - dumpd(self), - input, - run_type=run_type, - name=config.get("run_name"), - ) - try: - output = await acall_func_with_variable_args( - func, input, config, run_manager, **kwargs - ) - except BaseException as e: - await run_manager.on_chain_error(e) - raise - else: - await run_manager.on_chain_end(dumpd(output)) - return output - - def _batch_with_config( - self, - func: Union[ - Callable[[List[Input]], List[Union[Exception, Output]]], - Callable[ - [List[Input], List[CallbackManagerForChainRun]], - List[Union[Exception, Output]], - ], - Callable[ - [List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]], - List[Union[Exception, Output]], - ], - ], - input: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - run_type: Optional[str] = None, - **kwargs: Optional[Any], - ) -> List[Output]: - """Helper method to transform an Input value to an Output value, - with callbacks. Use this method to implement invoke() in subclasses.""" - if not input: - return [] - - configs = get_config_list(config, len(input)) - callback_managers = [get_callback_manager_for_config(c) for c in configs] - run_managers = [ - callback_manager.on_chain_start( - dumpd(self), - input, - run_type=run_type, - name=config.get("run_name"), - ) - for callback_manager, input, config in zip( - callback_managers, input, configs - ) - ] - try: - if accepts_config(func): - kwargs["config"] = [ - patch_config(c, callbacks=rm.get_child()) - for c, rm in zip(configs, run_managers) - ] - if accepts_run_manager(func): - kwargs["run_manager"] = run_managers - output = func(input, **kwargs) # type: ignore[call-arg] - except BaseException as e: - for run_manager in run_managers: - run_manager.on_chain_error(e) - if return_exceptions: - return cast(List[Output], [e for _ in input]) - else: - raise - else: - first_exception: Optional[Exception] = None - for run_manager, out in zip(run_managers, output): - if isinstance(out, Exception): - first_exception = first_exception or out - run_manager.on_chain_error(out) - else: - run_manager.on_chain_end(dumpd(out)) - if return_exceptions or first_exception is None: - return cast(List[Output], output) - else: - raise first_exception - - async def _abatch_with_config( - self, - func: Union[ - Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]], - Callable[ - [List[Input], List[AsyncCallbackManagerForChainRun]], - Awaitable[List[Union[Exception, Output]]], - ], - Callable[ - [ - List[Input], - List[AsyncCallbackManagerForChainRun], - List[RunnableConfig], - ], - Awaitable[List[Union[Exception, Output]]], - ], - ], - input: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - run_type: Optional[str] = None, - **kwargs: Optional[Any], - ) -> List[Output]: - """Helper method to transform an Input value to an Output value, - with callbacks. 
Use this method to implement invoke() in subclasses.""" - if not input: - return [] - - configs = get_config_list(config, len(input)) - callback_managers = [get_async_callback_manager_for_config(c) for c in configs] - run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( - *( - callback_manager.on_chain_start( - dumpd(self), - input, - run_type=run_type, - name=config.get("run_name"), - ) - for callback_manager, input, config in zip( - callback_managers, input, configs - ) - ) - ) - try: - if accepts_config(func): - kwargs["config"] = [ - patch_config(c, callbacks=rm.get_child()) - for c, rm in zip(configs, run_managers) - ] - if accepts_run_manager(func): - kwargs["run_manager"] = run_managers - output = await func(input, **kwargs) # type: ignore[call-arg] - except BaseException as e: - await asyncio.gather( - *(run_manager.on_chain_error(e) for run_manager in run_managers) - ) - if return_exceptions: - return cast(List[Output], [e for _ in input]) - else: - raise - else: - first_exception: Optional[Exception] = None - coros: List[Awaitable[None]] = [] - for run_manager, out in zip(run_managers, output): - if isinstance(out, Exception): - first_exception = first_exception or out - coros.append(run_manager.on_chain_error(out)) - else: - coros.append(run_manager.on_chain_end(dumpd(out))) - await asyncio.gather(*coros) - if return_exceptions or first_exception is None: - return cast(List[Output], output) - else: - raise first_exception - - def _transform_stream_with_config( - self, - input: Iterator[Input], - transformer: Union[ - Callable[[Iterator[Input]], Iterator[Output]], - Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]], - Callable[ - [ - Iterator[Input], - CallbackManagerForChainRun, - RunnableConfig, - ], - Iterator[Output], - ], - ], - config: Optional[RunnableConfig], - run_type: Optional[str] = None, - **kwargs: Optional[Any], - ) -> Iterator[Output]: - """Helper method to transform an Iterator of Input values into an Iterator of - Output values, with callbacks. 
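For example, a streaming subclass might route its generator through this helper so that tracing sees the run (the `Upper` class is hypothetical, and the helper is internal, so treat this as a sketch of the pattern, not a public contract):

.. code-block:: python

    from typing import Any, Iterator, Optional

    from langchain.schema.runnable import Runnable, RunnableConfig

    class Upper(Runnable[str, str]):
        """Hypothetical runnable that upper-cases each chunk as it arrives."""

        def _upper(self, chunks: Iterator[str]) -> Iterator[str]:
            for chunk in chunks:
                yield chunk.upper()

        def transform(
            self,
            input: Iterator[str],
            config: Optional[RunnableConfig] = None,
            **kwargs: Any,
        ) -> Iterator[str]:
            # The helper tees the input for tracing and emits chain callbacks
            # around the wrapped generator.
            yield from self._transform_stream_with_config(
                input, self._upper, config, **kwargs
            )

        def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> str:
            return "".join(self.transform(iter([input]), config))

    assert list(Upper().transform(iter(["a", "b"]))) == ["A", "B"]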
- Use this to implement `stream()` or `transform()` in Runnable subclasses.""" - # tee the input so we can iterate over it twice - input_for_tracing, input_for_transform = tee(input, 2) - # Start the input iterator to ensure the input runnable starts before this one - final_input: Optional[Input] = next(input_for_tracing, None) - final_input_supported = True - final_output: Optional[Output] = None - final_output_supported = True - - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( - dumpd(self), - {"input": ""}, - run_type=run_type, - name=config.get("run_name"), - ) - try: - if accepts_config(transformer): - kwargs["config"] = patch_config( - config, callbacks=run_manager.get_child() - ) - if accepts_run_manager(transformer): - kwargs["run_manager"] = run_manager - iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg] - for chunk in iterator: - yield chunk - if final_output_supported: - if final_output is None: - final_output = chunk - else: - try: - final_output = final_output + chunk # type: ignore - except TypeError: - final_output = None - final_output_supported = False - for ichunk in input_for_tracing: - if final_input_supported: - if final_input is None: - final_input = ichunk - else: - try: - final_input = final_input + ichunk # type: ignore - except TypeError: - final_input = None - final_input_supported = False - except BaseException as e: - run_manager.on_chain_error(e, inputs=final_input) - raise - else: - run_manager.on_chain_end(final_output, inputs=final_input) - - async def _atransform_stream_with_config( - self, - input: AsyncIterator[Input], - transformer: Union[ - Callable[[AsyncIterator[Input]], AsyncIterator[Output]], - Callable[ - [AsyncIterator[Input], AsyncCallbackManagerForChainRun], - AsyncIterator[Output], - ], - Callable[ - [ - AsyncIterator[Input], - AsyncCallbackManagerForChainRun, - RunnableConfig, - ], - AsyncIterator[Output], - ], - ], - config: Optional[RunnableConfig], - run_type: Optional[str] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Output]: - """Helper method to transform an Async Iterator of Input values into an Async - Iterator of Output values, with callbacks. 
- Use this to implement `astream()` or `atransform()` in Runnable subclasses.""" - # tee the input so we can iterate over it twice - input_for_tracing, input_for_transform = atee(input, 2) - # Start the input iterator to ensure the input runnable starts before this one - final_input: Optional[Input] = await py_anext(input_for_tracing, None) - final_input_supported = True - final_output: Optional[Output] = None - final_output_supported = True - - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) - run_manager = await callback_manager.on_chain_start( - dumpd(self), - {"input": ""}, - run_type=run_type, - name=config.get("run_name"), - ) - try: - if accepts_config(transformer): - kwargs["config"] = patch_config( - config, callbacks=run_manager.get_child() - ) - if accepts_run_manager(transformer): - kwargs["run_manager"] = run_manager - iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg] - async for chunk in iterator: - yield chunk - if final_output_supported: - if final_output is None: - final_output = chunk - else: - try: - final_output = final_output + chunk # type: ignore - except TypeError: - final_output = None - final_output_supported = False - async for ichunk in input_for_tracing: - if final_input_supported: - if final_input is None: - final_input = ichunk - else: - try: - final_input = final_input + ichunk # type: ignore[operator] - except TypeError: - final_input = None - final_input_supported = False - except BaseException as e: - await run_manager.on_chain_error(e, inputs=final_input) - raise - else: - await run_manager.on_chain_end(final_output, inputs=final_input) - - -class RunnableSerializable(Serializable, Runnable[Input, Output]): - """A Runnable that can be serialized to JSON.""" - - def configurable_fields( - self, **kwargs: AnyConfigurableField - ) -> RunnableSerializable[Input, Output]: - from langchain.schema.runnable.configurable import RunnableConfigurableFields - - for key in kwargs: - if key not in self.__fields__: - raise ValueError( - f"Configuration key {key} not found in {self}: " - "available keys are {self.__fields__.keys()}" - ) - - return RunnableConfigurableFields(default=self, fields=kwargs) - - def configurable_alternatives( - self, - which: ConfigurableField, - default_key: str = "default", - **kwargs: Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]], - ) -> RunnableSerializable[Input, Output]: - from langchain.schema.runnable.configurable import ( - RunnableConfigurableAlternatives, - ) - - return RunnableConfigurableAlternatives( - which=which, default=self, alternatives=kwargs, default_key=default_key - ) - - -class RunnableSequence(RunnableSerializable[Input, Output]): - """A sequence of runnables, where the output of each is the input of the next. - - RunnableSequence is the most important composition operator in LangChain as it is - used in virtually every chain. - - A RunnableSequence can be instantiated directly or more commonly by using the `|` - operator where either the left or right operands (or both) must be a Runnable. - - Any RunnableSequence automatically supports sync, async, batch. - - The default implementations of `batch` and `abatch` utilize threadpools and - asyncio gather and will be faster than naive invocation of invoke or ainvoke - for IO bound runnables. - - Batching is implemented by invoking the batch method on each component of the - RunnableSequence in order. 
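To make the per-step batching concrete, here is a small sketch (the `_Recorder` subclass is purely illustrative; it only records how many inputs each `batch()` call receives):

.. code-block:: python

    from typing import List

    from langchain.schema.runnable import RunnableLambda

    batch_sizes: List[int] = []

    class _Recorder(RunnableLambda):
        """Hypothetical RunnableLambda that records its batch sizes."""

        def batch(self, inputs, config=None, *, return_exceptions=False, **kwargs):
            batch_sizes.append(len(inputs))
            return super().batch(
                inputs, config, return_exceptions=return_exceptions, **kwargs
            )

    seq = _Recorder(lambda x: x + 1) | _Recorder(lambda x: x * 2)

    assert seq.batch([1, 2, 3]) == [4, 6, 8]
    # Each step saw a single batch of all three inputs, rather than the whole
    # sequence being invoked once per input.
    assert batch_sizes == [3, 3]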
-
-    A RunnableSequence preserves the streaming properties of its components, so if all
-    components of the sequence implement a `transform` method -- which
-    is the method that implements the logic to map a streaming input to a streaming
-    output -- then the sequence will be able to stream input to output!
-
-    If any component of the sequence does not implement transform then the
-    streaming will only begin after this component is run. If there are
-    multiple blocking components, streaming begins after the last one.
-
-    Please note: RunnableLambdas do not support `transform` by default! So if
-    you need to use a RunnableLambda, be careful about where you place it in a
-    RunnableSequence (if you need to use the .stream()/.astream() methods).
-
-    If you need arbitrary logic and streaming, you can subclass
-    Runnable and implement `transform` for whatever logic you need.
-
-    Here is a simple example that uses two simple functions to illustrate the use of
-    RunnableSequence:
-
-        .. code-block:: python
-
-            from langchain.schema.runnable import RunnableLambda
-
-            def add_one(x: int) -> int:
-                return x + 1
-
-            def mul_two(x: int) -> int:
-                return x * 2
-
-            runnable_1 = RunnableLambda(add_one)
-            runnable_2 = RunnableLambda(mul_two)
-            sequence = runnable_1 | runnable_2
-            # Or equivalently:
-            # sequence = RunnableSequence(first=runnable_1, last=runnable_2)
-            sequence.invoke(1)
-            await sequence.ainvoke(1)
-
-            sequence.batch([1, 2, 3])
-            await sequence.abatch([1, 2, 3])
-
-    Here's an example that streams JSON output generated by an LLM:
-
-        .. code-block:: python
-
-            from langchain.output_parsers.json import SimpleJsonOutputParser
-            from langchain.chat_models.openai import ChatOpenAI
-            from langchain.prompts import PromptTemplate
-
-            prompt = PromptTemplate.from_template(
-                'In JSON format, give me a list of {topic} and their '
-                'corresponding names in French, Spanish and in a '
-                'Cat Language.'
-            )
-
-            model = ChatOpenAI()
-            chain = prompt | model | SimpleJsonOutputParser()
-
-            async for chunk in chain.astream({'topic': 'colors'}):
-                print('-')
-                print(chunk, sep='', flush=True)
-    """
-
-    # The steps are broken into first, middle and last, solely for type checking
-    # purposes. It allows specifying the `Input` on the first type, the `Output` of
-    # the last type.
- first: Runnable[Input, Any] - """The first runnable in the sequence.""" - middle: List[Runnable[Any, Any]] = Field(default_factory=list) - """The middle runnables in the sequence.""" - last: Runnable[Any, Output] - """The last runnable in the sequence.""" - - @property - def steps(self) -> List[Runnable[Any, Any]]: - """All the runnables that make up the sequence in order.""" - return [self.first] + self.middle + [self.last] - - @classmethod - def is_lc_serializable(cls) -> bool: - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - return cls.__module__.split(".")[:-1] - - class Config: - arbitrary_types_allowed = True - - @property - def InputType(self) -> Type[Input]: - return self.first.InputType - - @property - def OutputType(self) -> Type[Output]: - return self.last.OutputType - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - from langchain.schema.runnable.passthrough import RunnableAssign - - if isinstance(self.first, RunnableAssign): - first = cast(RunnableAssign, self.first) - next_ = self.middle[0] if self.middle else self.last - next_input_schema = next_.get_input_schema(config) - if not next_input_schema.__custom_root_type__: - # it's a dict as expected - return create_model( # type: ignore[call-overload] - "RunnableSequenceInput", - **{ - k: (v.annotation, v.default) - for k, v in next_input_schema.__fields__.items() - if k not in first.mapper.steps - }, - ) - - return self.first.get_input_schema(config) - - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self.last.get_output_schema(config) - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - return get_unique_config_specs( - spec for step in self.steps for spec in step.config_specs - ) - - def __repr__(self) -> str: - return "\n| ".join( - repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ") - for i, s in enumerate(self.steps) - ) - - def __or__( - self, - other: Union[ - Runnable[Any, Other], - Callable[[Any], Other], - Callable[[Iterator[Any]], Iterator[Other]], - Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], - ], - ) -> RunnableSerializable[Input, Other]: - if isinstance(other, RunnableSequence): - return RunnableSequence( - first=self.first, - middle=self.middle + [self.last] + [other.first] + other.middle, - last=other.last, - ) - else: - return RunnableSequence( - first=self.first, - middle=self.middle + [self.last], - last=coerce_to_runnable(other), - ) - - def __ror__( - self, - other: Union[ - Runnable[Other, Any], - Callable[[Other], Any], - Callable[[Iterator[Other]], Iterator[Any]], - Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], - ], - ) -> RunnableSerializable[Other, Output]: - if isinstance(other, RunnableSequence): - return RunnableSequence( - first=other.first, - middle=other.middle + [other.last] + [self.first] + self.middle, - last=self.last, - ) - else: - return RunnableSequence( - first=coerce_to_runnable(other), - middle=[self.first] + self.middle, - last=self.last, - ) - - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - # setup callbacks - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - # start the root run - run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") - ) - - # invoke all steps in sequence - try: - for i, step in enumerate(self.steps): - input = step.invoke( - 
input, - # mark each step as a child run - patch_config( - config, callbacks=run_manager.get_child(f"seq:step:{i+1}") - ), - ) - # finish the root run - except BaseException as e: - run_manager.on_chain_error(e) - raise - else: - run_manager.on_chain_end(input) - return cast(Output, input) - - async def ainvoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Output: - # setup callbacks - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) - # start the root run - run_manager = await callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") - ) - - # invoke all steps in sequence - try: - for i, step in enumerate(self.steps): - input = await step.ainvoke( - input, - # mark each step as a child run - patch_config( - config, callbacks=run_manager.get_child(f"seq:step:{i+1}") - ), - ) - # finish the root run - except BaseException as e: - await run_manager.on_chain_error(e) - raise - else: - await run_manager.on_chain_end(input) - return cast(Output, input) - - def batch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - from langchain.callbacks.manager import CallbackManager - - if not inputs: - return [] - - # setup callbacks - configs = get_config_list(config, len(inputs)) - callback_managers = [ - CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) - for config in configs - ] - # start the root runs, one per input - run_managers = [ - cm.on_chain_start( - dumpd(self), - input, - name=config.get("run_name"), - ) - for cm, input, config in zip(callback_managers, inputs, configs) - ] - - # invoke - try: - if return_exceptions: - # Track which inputs (by index) failed so far - # If an input has failed it will be present in this map, - # and the value will be the exception that was raised. - failed_inputs_map: Dict[int, Exception] = {} - for stepidx, step in enumerate(self.steps): - # Assemble the original indexes of the remaining inputs - # (i.e. 
the ones that haven't failed yet) - remaining_idxs = [ - i for i in range(len(configs)) if i not in failed_inputs_map - ] - # Invoke the step on the remaining inputs - inputs = step.batch( - [ - inp - for i, inp in zip(remaining_idxs, inputs) - if i not in failed_inputs_map - ], - [ - # each step a child run of the corresponding root run - patch_config( - config, callbacks=rm.get_child(f"seq:step:{stepidx+1}") - ) - for i, (rm, config) in enumerate(zip(run_managers, configs)) - if i not in failed_inputs_map - ], - return_exceptions=return_exceptions, - **kwargs, - ) - # If an input failed, add it to the map - for i, inp in zip(remaining_idxs, inputs): - if isinstance(inp, Exception): - failed_inputs_map[i] = inp - inputs = [inp for inp in inputs if not isinstance(inp, Exception)] - # If all inputs have failed, stop processing - if len(failed_inputs_map) == len(configs): - break - - # Reassemble the outputs, inserting Exceptions for failed inputs - inputs_copy = inputs.copy() - inputs = [] - for i in range(len(configs)): - if i in failed_inputs_map: - inputs.append(cast(Input, failed_inputs_map[i])) - else: - inputs.append(inputs_copy.pop(0)) - else: - for i, step in enumerate(self.steps): - inputs = step.batch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config( - config, callbacks=rm.get_child(f"seq:step:{i+1}") - ) - for rm, config in zip(run_managers, configs) - ], - ) - - # finish the root runs - except BaseException as e: - for rm in run_managers: - rm.on_chain_error(e) - if return_exceptions: - return cast(List[Output], [e for _ in inputs]) - else: - raise - else: - first_exception: Optional[Exception] = None - for run_manager, out in zip(run_managers, inputs): - if isinstance(out, Exception): - first_exception = first_exception or out - run_manager.on_chain_error(out) - else: - run_manager.on_chain_end(dumpd(out)) - if return_exceptions or first_exception is None: - return cast(List[Output], inputs) - else: - raise first_exception - - async def abatch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - from langchain.callbacks.manager import ( - AsyncCallbackManager, - ) - - if not inputs: - return [] - - # setup callbacks - configs = get_config_list(config, len(inputs)) - callback_managers = [ - AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) - for config in configs - ] - # start the root runs, one per input - run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( - *( - cm.on_chain_start( - dumpd(self), - input, - name=config.get("run_name"), - ) - for cm, input, config in zip(callback_managers, inputs, configs) - ) - ) - - # invoke .batch() on each step - # this uses batching optimizations in Runnable subclasses, like LLM - try: - if return_exceptions: - # Track which inputs (by index) failed so far - # If an input has failed it will be present in this map, - # and the value will be the exception that was raised. - failed_inputs_map: Dict[int, Exception] = {} - for stepidx, step in enumerate(self.steps): - # Assemble the original indexes of the remaining inputs - # (i.e. 
the ones that haven't failed yet) - remaining_idxs = [ - i for i in range(len(configs)) if i not in failed_inputs_map - ] - # Invoke the step on the remaining inputs - inputs = await step.abatch( - [ - inp - for i, inp in zip(remaining_idxs, inputs) - if i not in failed_inputs_map - ], - [ - # each step a child run of the corresponding root run - patch_config( - config, callbacks=rm.get_child(f"seq:step:{stepidx+1}") - ) - for i, (rm, config) in enumerate(zip(run_managers, configs)) - if i not in failed_inputs_map - ], - return_exceptions=return_exceptions, - **kwargs, - ) - # If an input failed, add it to the map - for i, inp in zip(remaining_idxs, inputs): - if isinstance(inp, Exception): - failed_inputs_map[i] = inp - inputs = [inp for inp in inputs if not isinstance(inp, Exception)] - # If all inputs have failed, stop processing - if len(failed_inputs_map) == len(configs): - break - - # Reassemble the outputs, inserting Exceptions for failed inputs - inputs_copy = inputs.copy() - inputs = [] - for i in range(len(configs)): - if i in failed_inputs_map: - inputs.append(cast(Input, failed_inputs_map[i])) - else: - inputs.append(inputs_copy.pop(0)) - else: - for i, step in enumerate(self.steps): - inputs = await step.abatch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config( - config, callbacks=rm.get_child(f"seq:step:{i+1}") - ) - for rm, config in zip(run_managers, configs) - ], - ) - # finish the root runs - except BaseException as e: - await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) - if return_exceptions: - return cast(List[Output], [e for _ in inputs]) - else: - raise - else: - first_exception: Optional[Exception] = None - coros: List[Awaitable[None]] = [] - for run_manager, out in zip(run_managers, inputs): - if isinstance(out, Exception): - first_exception = first_exception or out - coros.append(run_manager.on_chain_error(out)) - else: - coros.append(run_manager.on_chain_end(dumpd(out))) - await asyncio.gather(*coros) - if return_exceptions or first_exception is None: - return cast(List[Output], inputs) - else: - raise first_exception - - def _transform( - self, - input: Iterator[Input], - run_manager: CallbackManagerForChainRun, - config: RunnableConfig, - ) -> Iterator[Output]: - steps = [self.first] + self.middle + [self.last] - - # transform the input stream of each step with the next - # steps that don't natively support transforming an input stream will - # buffer input in memory until all available, and then start emitting output - final_pipeline = cast(Iterator[Output], input) - for step in steps: - final_pipeline = step.transform( - final_pipeline, - patch_config( - config, - callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"), - ), - ) - - for output in final_pipeline: - yield output - - async def _atransform( - self, - input: AsyncIterator[Input], - run_manager: AsyncCallbackManagerForChainRun, - config: RunnableConfig, - ) -> AsyncIterator[Output]: - steps = [self.first] + self.middle + [self.last] - - # stream the last steps - # transform the input stream of each step with the next - # steps that don't natively support transforming an input stream will - # buffer input in memory until all available, and then start emitting output - final_pipeline = cast(AsyncIterator[Output], input) - for step in steps: - final_pipeline = step.atransform( - final_pipeline, - patch_config( - config, - callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"), - ), - ) - async for output in final_pipeline: - 
yield output - - def transform( - self, - input: Iterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Iterator[Output]: - yield from self._transform_stream_with_config( - input, self._transform, config, **kwargs - ) - - def stream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Iterator[Output]: - yield from self.transform(iter([input]), config, **kwargs) - - async def atransform( - self, - input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Output]: - async for chunk in self._atransform_stream_with_config( - input, self._atransform, config, **kwargs - ): - yield chunk - - async def astream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Output]: - async def input_aiter() -> AsyncIterator[Input]: - yield input - - async for chunk in self.atransform(input_aiter(), config, **kwargs): - yield chunk - - -class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): - """ - A runnable that runs a mapping of runnables in parallel, - and returns a mapping of their outputs. - """ - - steps: Mapping[str, Runnable[Input, Any]] - - def __init__( - self, - __steps: Optional[ - Mapping[ - str, - Union[ - Runnable[Input, Any], - Callable[[Input], Any], - Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]], - ], - ] - ] = None, - **kwargs: Union[ - Runnable[Input, Any], - Callable[[Input], Any], - Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]], - ], - ) -> None: - merged = {**__steps} if __steps is not None else {} - merged.update(kwargs) - super().__init__( - steps={key: coerce_to_runnable(r) for key, r in merged.items()} - ) - - @classmethod - def is_lc_serializable(cls) -> bool: - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - return cls.__module__.split(".")[:-1] - - class Config: - arbitrary_types_allowed = True - - @property - def InputType(self) -> Any: - for step in self.steps.values(): - if step.InputType: - return step.InputType - - return Any - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - if all( - s.get_input_schema(config).schema().get("type", "object") == "object" - for s in self.steps.values() - ): - # This is correct, but pydantic typings/mypy don't think so. - return create_model( # type: ignore[call-overload] - "RunnableParallelInput", - **{ - k: (v.annotation, v.default) - for step in self.steps.values() - for k, v in step.get_input_schema(config).__fields__.items() - if k != "__root__" - }, - ) - - return super().get_input_schema(config) - - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - # This is correct, but pydantic typings/mypy don't think so. 
- return create_model( # type: ignore[call-overload] - "RunnableParallelOutput", - **{k: (v.OutputType, None) for k, v in self.steps.items()}, - ) - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - return get_unique_config_specs( - spec for step in self.steps.values() for spec in step.config_specs - ) - - def __repr__(self) -> str: - map_for_repr = ",\n ".join( - f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}" - for k, v in self.steps.items() - ) - return "{\n " + map_for_repr + "\n}" - - def invoke( - self, input: Input, config: Optional[RunnableConfig] = None - ) -> Dict[str, Any]: - from langchain.callbacks.manager import CallbackManager - - # setup callbacks - config = ensure_config(config) - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) - # start the root run - run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") - ) - - # gather results from all steps - try: - # copy to avoid issues from the caller mutating the steps during invoke() - steps = dict(self.steps) - with get_executor_for_config(config) as executor: - futures = [ - executor.submit( - step.invoke, - input, - # mark each step as a child run - patch_config( - config, - callbacks=run_manager.get_child(f"map:key:{key}"), - ), - ) - for key, step in steps.items() - ] - output = {key: future.result() for key, future in zip(steps, futures)} - # finish the root run - except BaseException as e: - run_manager.on_chain_error(e) - raise - else: - run_manager.on_chain_end(output) - return output - - async def ainvoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Dict[str, Any]: - # setup callbacks - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) - # start the root run - run_manager = await callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") - ) - - # gather results from all steps - try: - # copy to avoid issues from the caller mutating the steps during invoke() - steps = dict(self.steps) - results = await asyncio.gather( - *( - step.ainvoke( - input, - # mark each step as a child run - patch_config( - config, callbacks=run_manager.get_child(f"map:key:{key}") - ), - ) - for key, step in steps.items() - ) - ) - output = {key: value for key, value in zip(steps, results)} - # finish the root run - except BaseException as e: - await run_manager.on_chain_error(e) - raise - else: - await run_manager.on_chain_end(output) - return output - - def _transform( - self, - input: Iterator[Input], - run_manager: CallbackManagerForChainRun, - config: RunnableConfig, - ) -> Iterator[AddableDict]: - # Shallow copy steps to ignore mutations while in progress - steps = dict(self.steps) - # Each step gets a copy of the input iterator, - # which is consumed in parallel in a separate thread. 
- input_copies = list(safetee(input, len(steps), lock=threading.Lock())) - with get_executor_for_config(config) as executor: - # Create the transform() generator for each step - named_generators = [ - ( - name, - step.transform( - input_copies.pop(), - patch_config( - config, callbacks=run_manager.get_child(f"map:key:{name}") - ), - ), - ) - for name, step in steps.items() - ] - # Start the first iteration of each generator - futures = { - executor.submit(next, generator): (step_name, generator) - for step_name, generator in named_generators - } - # Yield chunks from each as they become available, - # and start the next iteration of that generator that yielded it. - # When all generators are exhausted, stop. - while futures: - completed_futures, _ = wait(futures, return_when=FIRST_COMPLETED) - for future in completed_futures: - (step_name, generator) = futures.pop(future) - try: - chunk = AddableDict({step_name: future.result()}) - yield chunk - futures[executor.submit(next, generator)] = ( - step_name, - generator, - ) - except StopIteration: - pass - - def transform( - self, - input: Iterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Iterator[Dict[str, Any]]: - yield from self._transform_stream_with_config( - input, self._transform, config, **kwargs - ) - - def stream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Iterator[Dict[str, Any]]: - yield from self.transform(iter([input]), config) - - async def _atransform( - self, - input: AsyncIterator[Input], - run_manager: AsyncCallbackManagerForChainRun, - config: RunnableConfig, - ) -> AsyncIterator[AddableDict]: - # Shallow copy steps to ignore mutations while in progress - steps = dict(self.steps) - # Each step gets a copy of the input iterator, - # which is consumed in parallel in a separate thread. - input_copies = list(atee(input, len(steps), lock=asyncio.Lock())) - # Create the transform() generator for each step - named_generators = [ - ( - name, - step.atransform( - input_copies.pop(), - patch_config( - config, callbacks=run_manager.get_child(f"map:key:{name}") - ), - ), - ) - for name, step in steps.items() - ] - - # Wrap in a coroutine to satisfy linter - async def get_next_chunk(generator: AsyncIterator) -> Optional[Output]: - return await py_anext(generator) - - # Start the first iteration of each generator - tasks = { - asyncio.create_task(get_next_chunk(generator)): (step_name, generator) - for step_name, generator in named_generators - } - # Yield chunks from each as they become available, - # and start the next iteration of the generator that yielded it. - # When all generators are exhausted, stop. 
- while tasks: - completed_tasks, _ = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED - ) - for task in completed_tasks: - (step_name, generator) = tasks.pop(task) - try: - chunk = AddableDict({step_name: task.result()}) - yield chunk - new_task = asyncio.create_task(get_next_chunk(generator)) - tasks[new_task] = (step_name, generator) - except StopAsyncIteration: - pass - - async def atransform( - self, - input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> AsyncIterator[Dict[str, Any]]: - async for chunk in self._atransform_stream_with_config( - input, self._atransform, config, **kwargs - ): - yield chunk - - async def astream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Dict[str, Any]]: - async def input_aiter() -> AsyncIterator[Input]: - yield input - - async for chunk in self.atransform(input_aiter(), config): - yield chunk - - -# We support both names -RunnableMap = RunnableParallel - - -class RunnableGenerator(Runnable[Input, Output]): - """ - A runnable that runs a generator function. - """ - - def __init__( - self, - transform: Union[ - Callable[[Iterator[Input]], Iterator[Output]], - Callable[[AsyncIterator[Input]], AsyncIterator[Output]], - ], - atransform: Optional[ - Callable[[AsyncIterator[Input]], AsyncIterator[Output]] - ] = None, - ) -> None: - if atransform is not None: - self._atransform = atransform - - if inspect.isasyncgenfunction(transform): - self._atransform = transform - elif inspect.isgeneratorfunction(transform): - self._transform = transform - else: - raise TypeError( - "Expected a generator function type for `transform`." - f"Instead got an unsupported type: {type(transform)}" - ) - - @property - def InputType(self) -> Any: - func = getattr(self, "_transform", None) or getattr(self, "_atransform") - try: - params = inspect.signature(func).parameters - first_param = next(iter(params.values()), None) - if first_param and first_param.annotation != inspect.Parameter.empty: - return getattr(first_param.annotation, "__args__", (Any,))[0] - else: - return Any - except ValueError: - return Any - - @property - def OutputType(self) -> Any: - func = getattr(self, "_transform", None) or getattr(self, "_atransform") - try: - sig = inspect.signature(func) - return ( - getattr(sig.return_annotation, "__args__", (Any,))[0] - if sig.return_annotation != inspect.Signature.empty - else Any - ) - except ValueError: - return Any - - def __eq__(self, other: Any) -> bool: - if isinstance(other, RunnableGenerator): - if hasattr(self, "_transform") and hasattr(other, "_transform"): - return self._transform == other._transform - elif hasattr(self, "_atransform") and hasattr(other, "_atransform"): - return self._atransform == other._atransform - else: - return False - else: - return False - - def __repr__(self) -> str: - return "RunnableGenerator(...)" - - def transform( - self, - input: Iterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Iterator[Output]: - return self._transform_stream_with_config( - input, self._transform, config, **kwargs - ) - - def stream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Iterator[Output]: - return self.transform(iter([input]), config, **kwargs) - - def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - final = None - for output in self.stream(input, config, **kwargs): - if final is None: - final 
= output - else: - final = final + output - return cast(Output, final) - - def atransform( - self, - input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> AsyncIterator[Output]: - if not hasattr(self, "_atransform"): - raise NotImplementedError("This runnable does not support async methods.") - - return self._atransform_stream_with_config( - input, self._atransform, config, **kwargs - ) - - def astream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> AsyncIterator[Output]: - async def input_aiter() -> AsyncIterator[Input]: - yield input - - return self.atransform(input_aiter(), config, **kwargs) - - async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - final = None - async for output in self.astream(input, config, **kwargs): - if final is None: - final = output - else: - final = final + output - return cast(Output, final) - - -class RunnableLambda(Runnable[Input, Output]): - """RunnableLambda converts a python callable into a Runnable. - - Wrapping a callable in a RunnableLambda makes the callable usable - within either a sync or async context. - - RunnableLambda can be composed as any other Runnable and provides - seamless integration with LangChain tracing. - - Examples: - - .. code-block:: python - - # This is a RunnableLambda - from langchain.schema.runnable import RunnableLambda - - def add_one(x: int) -> int: - return x + 1 - - runnable = RunnableLambda(add_one) - - runnable.invoke(1) # returns 2 - runnable.batch([1, 2, 3]) # returns [2, 3, 4] - - # Async is supported by default by delegating to the sync implementation - await runnable.ainvoke(1) # returns 2 - await runnable.abatch([1, 2, 3]) # returns [2, 3, 4] - - - # Alternatively, can provide both synd and sync implementations - async def add_one_async(x: int) -> int: - return x + 1 - - runnable = RunnableLambda(add_one, afunc=add_one_async) - runnable.invoke(1) # Uses add_one - await runnable.ainvoke(1) # Uses add_one_async - """ - - def __init__( - self, - func: Union[ - Union[ - Callable[[Input], Output], - Callable[[Input, RunnableConfig], Output], - Callable[[Input, CallbackManagerForChainRun], Output], - Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], - ], - Union[ - Callable[[Input], Awaitable[Output]], - Callable[[Input, RunnableConfig], Awaitable[Output]], - Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], - Callable[ - [Input, AsyncCallbackManagerForChainRun, RunnableConfig], - Awaitable[Output], - ], - ], - ], - afunc: Optional[ - Union[ - Callable[[Input], Awaitable[Output]], - Callable[[Input, RunnableConfig], Awaitable[Output]], - Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], - Callable[ - [Input, AsyncCallbackManagerForChainRun, RunnableConfig], - Awaitable[Output], - ], - ] - ] = None, - ) -> None: - """Create a RunnableLambda from a callable, and async callable or both. - - Accepts both sync and async variants to allow providing efficient - implementations for sync and async execution. - - Args: - func: Either sync or async callable - afunc: An async callable that takes an input and returns an output. - """ - if afunc is not None: - self.afunc = afunc - - if inspect.iscoroutinefunction(func): - if afunc is not None: - raise TypeError( - "Func was provided as a coroutine function, but afunc was " - "also provided. If providing both, func should be a regular " - "function to avoid ambiguity." 
- ) - self.afunc = func - elif callable(func): - self.func = cast(Callable[[Input], Output], func) - else: - raise TypeError( - "Expected a callable type for `func`." - f"Instead got an unsupported type: {type(func)}" - ) - - @property - def InputType(self) -> Any: - """The type of the input to this runnable.""" - func = getattr(self, "func", None) or getattr(self, "afunc") - try: - params = inspect.signature(func).parameters - first_param = next(iter(params.values()), None) - if first_param and first_param.annotation != inspect.Parameter.empty: - return first_param.annotation - else: - return Any - except ValueError: - return Any - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - """The pydantic schema for the input to this runnable.""" - func = getattr(self, "func", None) or getattr(self, "afunc") - - if isinstance(func, itemgetter): - # This is terrible, but afaict it's not possible to access _items - # on itemgetter objects, so we have to parse the repr - items = str(func).replace("operator.itemgetter(", "")[:-1].split(", ") - if all( - item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items - ): - # It's a dict, lol - return create_model( - "RunnableLambdaInput", - **{item[1:-1]: (Any, None) for item in items}, # type: ignore - ) - else: - return create_model("RunnableLambdaInput", __root__=(List[Any], None)) - - if self.InputType != Any: - return super().get_input_schema(config) - - if dict_keys := get_function_first_arg_dict_keys(func): - return create_model( - "RunnableLambdaInput", - **{key: (Any, None) for key in dict_keys}, # type: ignore - ) - - return super().get_input_schema(config) - - @property - def OutputType(self) -> Any: - """The type of the output of this runnable as a type annotation.""" - func = getattr(self, "func", None) or getattr(self, "afunc") - try: - sig = inspect.signature(func) - return ( - sig.return_annotation - if sig.return_annotation != inspect.Signature.empty - else Any - ) - except ValueError: - return Any - - def __eq__(self, other: Any) -> bool: - if isinstance(other, RunnableLambda): - if hasattr(self, "func") and hasattr(other, "func"): - return self.func == other.func - elif hasattr(self, "afunc") and hasattr(other, "afunc"): - return self.afunc == other.afunc - else: - return False - else: - return False - - def __repr__(self) -> str: - """A string representation of this runnable.""" - if hasattr(self, "func"): - return f"RunnableLambda({get_lambda_source(self.func) or '...'})" - elif hasattr(self, "afunc"): - return f"RunnableLambda(afunc={get_lambda_source(self.afunc) or '...'})" - else: - return "RunnableLambda(...)" - - def _invoke( - self, - input: Input, - run_manager: CallbackManagerForChainRun, - config: RunnableConfig, - **kwargs: Any, - ) -> Output: - output = call_func_with_variable_args( - self.func, input, config, run_manager, **kwargs - ) - # If the output is a runnable, invoke it - if isinstance(output, Runnable): - recursion_limit = config["recursion_limit"] - if recursion_limit <= 0: - raise RecursionError( - f"Recursion limit reached when invoking {self} with input {input}." 
- ) - output = output.invoke( - input, - patch_config( - config, - callbacks=run_manager.get_child(), - recursion_limit=recursion_limit - 1, - ), - ) - return output - - async def _ainvoke( - self, - input: Input, - run_manager: AsyncCallbackManagerForChainRun, - config: RunnableConfig, - **kwargs: Any, - ) -> Output: - output = await acall_func_with_variable_args( - self.afunc, input, config, run_manager, **kwargs - ) - # If the output is a runnable, invoke it - if isinstance(output, Runnable): - recursion_limit = config["recursion_limit"] - if recursion_limit <= 0: - raise RecursionError( - f"Recursion limit reached when invoking {self} with input {input}." - ) - output = await output.ainvoke( - input, - patch_config( - config, - callbacks=run_manager.get_child(), - recursion_limit=recursion_limit - 1, - ), - ) - return output - - def _config( - self, config: Optional[RunnableConfig], callable: Callable[..., Any] - ) -> RunnableConfig: - config = config or {} - - if config.get("run_name") is None: - try: - run_name = callable.__name__ - except AttributeError: - run_name = None - if run_name is not None: - return patch_config(config, run_name=run_name) - - return config - - def invoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Output: - """Invoke this runnable synchronously.""" - if hasattr(self, "func"): - return self._call_with_config( - self._invoke, - input, - self._config(config, self.func), - **kwargs, - ) - else: - raise TypeError( - "Cannot invoke a coroutine function synchronously." - "Use `ainvoke` instead." - ) - - async def ainvoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Output: - """Invoke this runnable asynchronously.""" - if hasattr(self, "afunc"): - return await self._acall_with_config( - self._ainvoke, - input, - self._config(config, self.afunc), - **kwargs, - ) - else: - # Delegating to super implementation of ainvoke. - # Uses asyncio executor to run the sync version (invoke) - return await super().ainvoke(input, config) - - -class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): - """ - A runnable that delegates calls to another runnable - with each element of the input sequence. - - Use only if creating a new RunnableEach subclass with different __init__ args. 
- """ - - bound: Runnable[Input, Output] - - class Config: - arbitrary_types_allowed = True - - @property - def InputType(self) -> Any: - return List[self.bound.InputType] # type: ignore[name-defined] - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return create_model( - "RunnableEachInput", - __root__=( - List[self.bound.get_input_schema(config)], # type: ignore - None, - ), - ) - - @property - def OutputType(self) -> Type[List[Output]]: - return List[self.bound.OutputType] # type: ignore[name-defined] - - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - schema = self.bound.get_output_schema(config) - return create_model( - "RunnableEachOutput", - __root__=( - List[schema], # type: ignore - None, - ), - ) - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - return self.bound.config_specs - - @classmethod - def is_lc_serializable(cls) -> bool: - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - return cls.__module__.split(".")[:-1] - - def _invoke( - self, - inputs: List[Input], - run_manager: CallbackManagerForChainRun, - config: RunnableConfig, - **kwargs: Any, - ) -> List[Output]: - return self.bound.batch( - inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs - ) - - def invoke( - self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> List[Output]: - return self._call_with_config(self._invoke, input, config, **kwargs) - - async def _ainvoke( - self, - inputs: List[Input], - run_manager: AsyncCallbackManagerForChainRun, - config: RunnableConfig, - **kwargs: Any, - ) -> List[Output]: - return await self.bound.abatch( - inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs - ) - - async def ainvoke( - self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> List[Output]: - return await self._acall_with_config(self._ainvoke, input, config, **kwargs) - - -class RunnableEach(RunnableEachBase[Input, Output]): - """ - A runnable that delegates calls to another runnable - with each element of the input sequence. - """ - - def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: - return RunnableEach(bound=self.bound.bind(**kwargs)) - - def with_config( - self, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> RunnableEach[Input, Output]: - return RunnableEach(bound=self.bound.with_config(config, **kwargs)) - - def with_listeners( - self, - *, - on_start: Optional[Listener] = None, - on_end: Optional[Listener] = None, - on_error: Optional[Listener] = None, - ) -> RunnableEach[Input, Output]: - """ - Bind lifecycle listeners to a Runnable, returning a new Runnable. - - on_start: Called before the runnable starts running, with the Run object. - on_end: Called after the runnable finishes running, with the Run object. - on_error: Called if the runnable throws an error, with the Run object. - - The Run object contains information about the run, including its id, - type, input, output, error, start_time, end_time, and any tags or metadata - added to the run. - """ - return RunnableEach( - bound=self.bound.with_listeners( - on_start=on_start, on_end=on_end, on_error=on_error - ) - ) - - -class RunnableBindingBase(RunnableSerializable[Input, Output]): - """ - A runnable that delegates calls to another runnable with a set of kwargs. - - Use only if creating a new RunnableBinding subclass with different __init__ args. 
- """ - - bound: Runnable[Input, Output] - - kwargs: Mapping[str, Any] = Field(default_factory=dict) - - config: RunnableConfig = Field(default_factory=dict) - - config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field( - default_factory=list - ) - - # Union[Type[Input], BaseModel] + things like List[str] - custom_input_type: Optional[Any] = None - # Union[Type[Output], BaseModel] + things like List[str] - custom_output_type: Optional[Any] = None - - class Config: - arbitrary_types_allowed = True - - def __init__( - self, - *, - bound: Runnable[Input, Output], - kwargs: Optional[Mapping[str, Any]] = None, - config: Optional[RunnableConfig] = None, - config_factories: Optional[ - List[Callable[[RunnableConfig], RunnableConfig]] - ] = None, - custom_input_type: Optional[Union[Type[Input], BaseModel]] = None, - custom_output_type: Optional[Union[Type[Output], BaseModel]] = None, - **other_kwargs: Any, - ) -> None: - config = config or {} - # config_specs contains the list of valid `configurable` keys - if configurable := config.get("configurable", None): - allowed_keys = set(s.id for s in bound.config_specs) - for key in configurable: - if key not in allowed_keys: - raise ValueError( - f"Configurable key '{key}' not found in runnable with" - f" config keys: {allowed_keys}" - ) - super().__init__( - bound=bound, - kwargs=kwargs or {}, - config=config or {}, - config_factories=config_factories or [], - custom_input_type=custom_input_type, - custom_output_type=custom_output_type, - **other_kwargs, - ) - - @property - def InputType(self) -> Type[Input]: - return ( - cast(Type[Input], self.custom_input_type) - if self.custom_input_type is not None - else self.bound.InputType - ) - - @property - def OutputType(self) -> Type[Output]: - return ( - cast(Type[Output], self.custom_output_type) - if self.custom_output_type is not None - else self.bound.OutputType - ) - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - if self.custom_input_type is not None: - return super().get_input_schema(config) - return self.bound.get_input_schema(merge_configs(self.config, config)) - - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - if self.custom_output_type is not None: - return super().get_output_schema(config) - return self.bound.get_output_schema(merge_configs(self.config, config)) - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - return self.bound.config_specs - - @classmethod - def is_lc_serializable(cls) -> bool: - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - return cls.__module__.split(".")[:-1] - - def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: - config = merge_configs(self.config, *configs) - return merge_configs(config, *(f(config) for f in self.config_factories)) - - def invoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Output: - return self.bound.invoke( - input, - self._merge_configs(config), - **{**self.kwargs, **kwargs}, - ) - - async def ainvoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Output: - return await self.bound.ainvoke( - input, - self._merge_configs(config), - **{**self.kwargs, **kwargs}, - ) - - def batch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) 
-> List[Output]: - if isinstance(config, list): - configs = cast( - List[RunnableConfig], - [self._merge_configs(conf) for conf in config], - ) - else: - configs = [self._merge_configs(config) for _ in range(len(inputs))] - return self.bound.batch( - inputs, - configs, - return_exceptions=return_exceptions, - **{**self.kwargs, **kwargs}, - ) - - async def abatch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - if isinstance(config, list): - configs = cast( - List[RunnableConfig], - [self._merge_configs(conf) for conf in config], - ) - else: - configs = [self._merge_configs(config) for _ in range(len(inputs))] - return await self.bound.abatch( - inputs, - configs, - return_exceptions=return_exceptions, - **{**self.kwargs, **kwargs}, - ) - - def stream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Iterator[Output]: - yield from self.bound.stream( - input, - self._merge_configs(config), - **{**self.kwargs, **kwargs}, - ) - - async def astream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Output]: - async for item in self.bound.astream( - input, - self._merge_configs(config), - **{**self.kwargs, **kwargs}, - ): - yield item - - def transform( - self, - input: Iterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Iterator[Output]: - yield from self.bound.transform( - input, - self._merge_configs(config), - **{**self.kwargs, **kwargs}, - ) - - async def atransform( - self, - input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> AsyncIterator[Output]: - async for item in self.bound.atransform( - input, - self._merge_configs(config), - **{**self.kwargs, **kwargs}, - ): - yield item - - -RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig) - - -class RunnableBinding(RunnableBindingBase[Input, Output]): - """ - A runnable that delegates calls to another runnable with a set of kwargs. - """ - - def bind(self, **kwargs: Any) -> Runnable[Input, Output]: - return self.__class__( - bound=self.bound, - config=self.config, - kwargs={**self.kwargs, **kwargs}, - custom_input_type=self.custom_input_type, - custom_output_type=self.custom_output_type, - ) - - def with_config( - self, - config: Optional[RunnableConfig] = None, - # Sadly Unpack is not well supported by mypy so this will have to be untyped - **kwargs: Any, - ) -> Runnable[Input, Output]: - return self.__class__( - bound=self.bound, - kwargs=self.kwargs, - config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}), - custom_input_type=self.custom_input_type, - custom_output_type=self.custom_output_type, - ) - - def with_listeners( - self, - *, - on_start: Optional[Listener] = None, - on_end: Optional[Listener] = None, - on_error: Optional[Listener] = None, - ) -> Runnable[Input, Output]: - """ - Bind lifecycle listeners to a Runnable, returning a new Runnable. - - on_start: Called before the runnable starts running, with the Run object. - on_end: Called after the runnable finishes running, with the Run object. - on_error: Called if the runnable throws an error, with the Run object. - - The Run object contains information about the run, including its id, - type, input, output, error, start_time, end_time, and any tags or metadata - added to the run. 
- """ - from langchain.callbacks.tracers.root_listeners import RootListenersTracer - - return self.__class__( - bound=self.bound, - kwargs=self.kwargs, - config=self.config, - config_factories=[ - lambda config: { - "callbacks": [ - RootListenersTracer( - config=config, - on_start=on_start, - on_end=on_end, - on_error=on_error, - ) - ], - } - ], - custom_input_type=self.custom_input_type, - custom_output_type=self.custom_output_type, - ) - - def with_types( - self, - input_type: Optional[Union[Type[Input], BaseModel]] = None, - output_type: Optional[Union[Type[Output], BaseModel]] = None, - ) -> Runnable[Input, Output]: - return self.__class__( - bound=self.bound, - kwargs=self.kwargs, - config=self.config, - custom_input_type=input_type - if input_type is not None - else self.custom_input_type, - custom_output_type=output_type - if output_type is not None - else self.custom_output_type, - ) - - def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: - return self.__class__( - bound=self.bound.with_retry(**kwargs), - kwargs=self.kwargs, - config=self.config, - ) - - -RunnableLike = Union[ - Runnable[Input, Output], - Callable[[Input], Output], - Callable[[Input], Awaitable[Output]], - Callable[[Iterator[Input]], Iterator[Output]], - Callable[[AsyncIterator[Input]], AsyncIterator[Output]], - Mapping[str, Any], +__all__ = [ + "Runnable", + "RunnableSerializable", + "RunnableSequence", + "RunnableParallel", + "RunnableGenerator", + "RunnableLambda", + "RunnableEachBase", + "RunnableEach", + "RunnableBindingBase", + "RunnableBinding", + "coerce_to_runnable", ] - - -def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: - """Coerce a runnable-like object into a Runnable. - - Args: - thing: A runnable-like object. - - Returns: - A Runnable. - """ - if isinstance(thing, Runnable): - return thing - elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing): - return RunnableGenerator(thing) - elif callable(thing): - return RunnableLambda(cast(Callable[[Input], Output], thing)) - elif isinstance(thing, dict): - return cast(Runnable[Input, Output], RunnableParallel(thing)) - else: - raise TypeError( - f"Expected a Runnable, callable or dict." - f"Instead got an unsupported type: {type(thing)}" - ) diff --git a/libs/langchain/langchain/schema/runnable/branch.py b/libs/langchain/langchain/schema/runnable/branch.py index 11fae9e0808..ed83f197aa0 100644 --- a/libs/langchain/langchain/schema/runnable/branch.py +++ b/libs/langchain/langchain/schema/runnable/branch.py @@ -1,254 +1,3 @@ -from typing import ( - Any, - Awaitable, - Callable, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - Union, - cast, -) +from langchain_core.runnables.branch import RunnableBranch -from langchain.load.dump import dumpd -from langchain.pydantic_v1 import BaseModel -from langchain.schema.runnable.base import ( - Runnable, - RunnableLike, - RunnableSerializable, - coerce_to_runnable, -) -from langchain.schema.runnable.config import ( - RunnableConfig, - ensure_config, - get_callback_manager_for_config, - patch_config, -) -from langchain.schema.runnable.utils import ( - ConfigurableFieldSpec, - Input, - Output, - get_unique_config_specs, -) - - -class RunnableBranch(RunnableSerializable[Input, Output]): - """A Runnable that selects which branch to run based on a condition. - - The runnable is initialized with a list of (condition, runnable) pairs and - a default branch. 
- - When operating on an input, the first condition that evaluates to True is - selected, and the corresponding runnable is run on the input. - - If no condition evaluates to True, the default branch is run on the input. - - Examples: - - .. code-block:: python - - from langchain.schema.runnable import RunnableBranch - - branch = RunnableBranch( - (lambda x: isinstance(x, str), lambda x: x.upper()), - (lambda x: isinstance(x, int), lambda x: x + 1), - (lambda x: isinstance(x, float), lambda x: x * 2), - lambda x: "goodbye", - ) - - branch.invoke("hello") # "HELLO" - branch.invoke(None) # "goodbye" - """ - - branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]] - default: Runnable[Input, Output] - - def __init__( - self, - *branches: Union[ - Tuple[ - Union[ - Runnable[Input, bool], - Callable[[Input], bool], - Callable[[Input], Awaitable[bool]], - ], - RunnableLike, - ], - RunnableLike, # To accommodate the default branch - ], - ) -> None: - """A Runnable that runs one of two branches based on a condition.""" - if len(branches) < 2: - raise ValueError("RunnableBranch requires at least two branches") - - default = branches[-1] - - if not isinstance( - default, - (Runnable, Callable, Mapping), # type: ignore[arg-type] - ): - raise TypeError( - "RunnableBranch default must be runnable, callable or mapping." - ) - - default_ = cast( - Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default)) - ) - - _branches = [] - - for branch in branches[:-1]: - if not isinstance(branch, (tuple, list)): # type: ignore[arg-type] - raise TypeError( - f"RunnableBranch branches must be " - f"tuples or lists, not {type(branch)}" - ) - - if not len(branch) == 2: - raise ValueError( - f"RunnableBranch branches must be " - f"tuples or lists of length 2, not {len(branch)}" - ) - condition, runnable = branch - condition = cast(Runnable[Input, bool], coerce_to_runnable(condition)) - runnable = coerce_to_runnable(runnable) - _branches.append((condition, runnable)) - - super().__init__(branches=_branches, default=default_) - - class Config: - arbitrary_types_allowed = True - - @classmethod - def is_lc_serializable(cls) -> bool: - """RunnableBranch is serializable if all its branches are serializable.""" - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - """The namespace of a RunnableBranch is the namespace of its default branch.""" - return cls.__module__.split(".")[:-1] - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - runnables = ( - [self.default] - + [r for _, r in self.branches] - + [r for r, _ in self.branches] - ) - - for runnable in runnables: - if runnable.get_input_schema(config).schema().get("type") is not None: - return runnable.get_input_schema(config) - - return super().get_input_schema(config) - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - return get_unique_config_specs( - spec - for step in ( - [self.default] - + [r for _, r in self.branches] - + [r for r, _ in self.branches] - ) - for spec in step.config_specs - ) - - def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - """First evaluates the condition, then delegate to true or false branch.""" - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( - dumpd(self), - input, - name=config.get("run_name"), - ) - - try: - for idx, branch in enumerate(self.branches): - condition, runnable 
= branch - - expression_value = condition.invoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), - ), - ) - - if expression_value: - output = runnable.invoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), - ), - **kwargs, - ) - break - else: - output = self.default.invoke( - input, - config=patch_config( - config, callbacks=run_manager.get_child(tag="branch:default") - ), - **kwargs, - ) - except Exception as e: - run_manager.on_chain_error(e) - raise - run_manager.on_chain_end(dumpd(output)) - return output - - async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - """Async version of invoke.""" - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( - dumpd(self), - input, - name=config.get("run_name"), - ) - try: - for idx, branch in enumerate(self.branches): - condition, runnable = branch - - expression_value = await condition.ainvoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), - ), - ) - - if expression_value: - output = await runnable.ainvoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), - ), - **kwargs, - ) - break - else: - output = await self.default.ainvoke( - input, - config=patch_config( - config, callbacks=run_manager.get_child(tag="branch:default") - ), - **kwargs, - ) - except Exception as e: - run_manager.on_chain_error(e) - raise - run_manager.on_chain_end(dumpd(output)) - return output +__all__ = ["RunnableBranch"] diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 869c413ebf5..0906d862cf3 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -1,401 +1,25 @@ -from __future__ import annotations - -from concurrent.futures import Executor, ThreadPoolExecutor -from contextlib import contextmanager -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - Dict, - Generator, - List, - Optional, - Union, - cast, +from langchain_core.runnables.config import ( + EmptyDict, + RunnableConfig, + call_func_with_variable_args, + ensure_config, + get_async_callback_manager_for_config, + get_callback_manager_for_config, + get_config_list, + get_executor_for_config, + merge_configs, + patch_config, ) -from typing_extensions import TypedDict - -from langchain.schema.runnable.utils import ( - Input, - Output, - accepts_config, - accepts_run_manager, -) - -if TYPE_CHECKING: - from langchain.callbacks.base import BaseCallbackManager, Callbacks - from langchain.callbacks.manager import ( - AsyncCallbackManager, - AsyncCallbackManagerForChainRun, - CallbackManager, - CallbackManagerForChainRun, - ) -else: - # Pydantic validates through typed dicts, but - # the callbacks need forward refs updated - Callbacks = Optional[Union[List, Any]] - - -class EmptyDict(TypedDict, total=False): - """Empty dict type.""" - - pass - - -class RunnableConfig(TypedDict, total=False): - """Configuration for a Runnable.""" - - tags: List[str] - """ - Tags for this call and any sub-calls (eg. a Chain calling an LLM). - You can use these to filter calls. - """ - - metadata: Dict[str, Any] - """ - Metadata for this call and any sub-calls (eg. a Chain calling an LLM). 
- Keys should be strings, values should be JSON-serializable. - """ - - callbacks: Callbacks - """ - Callbacks for this call and any sub-calls (eg. a Chain calling an LLM). - Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. - """ - - run_name: str - """ - Name for the tracer run for this call. Defaults to the name of the class. - """ - - max_concurrency: Optional[int] - """ - Maximum number of parallel calls to make. If not provided, defaults to - ThreadPoolExecutor's default. - """ - - recursion_limit: int - """ - Maximum number of times a call can recurse. If not provided, defaults to 25. - """ - - configurable: Dict[str, Any] - """ - Runtime values for attributes previously made configurable on this Runnable, - or sub-Runnables, through .configurable_fields() or .configurable_alternatives(). - Check .output_schema() for a description of the attributes that have been made - configurable. - """ - - -def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: - """Ensure that a config is a dict with all keys present. - - Args: - config (Optional[RunnableConfig], optional): The config to ensure. - Defaults to None. - - Returns: - RunnableConfig: The ensured config. - """ - empty = RunnableConfig( - tags=[], - metadata={}, - callbacks=None, - recursion_limit=25, - ) - if config is not None: - empty.update( - cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}) - ) - return empty - - -def get_config_list( - config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int -) -> List[RunnableConfig]: - """Get a list of configs from a single config or a list of configs. - - It is useful for subclasses overriding batch() or abatch(). - - Args: - config (Optional[Union[RunnableConfig, List[RunnableConfig]]]): - The config or list of configs. - length (int): The length of the list. - - Returns: - List[RunnableConfig]: The list of configs. - - Raises: - ValueError: If the length of the list is not equal to the length of the inputs. - - """ - if length < 0: - raise ValueError(f"length must be >= 0, but got {length}") - if isinstance(config, list) and len(config) != length: - raise ValueError( - f"config must be a list of the same length as inputs, " - f"but got {len(config)} configs for {length} inputs" - ) - - return ( - list(map(ensure_config, config)) - if isinstance(config, list) - else [ensure_config(config) for _ in range(length)] - ) - - -def patch_config( - config: Optional[RunnableConfig], - *, - callbacks: Optional[BaseCallbackManager] = None, - recursion_limit: Optional[int] = None, - max_concurrency: Optional[int] = None, - run_name: Optional[str] = None, - configurable: Optional[Dict[str, Any]] = None, -) -> RunnableConfig: - """Patch a config with new values. - - Args: - config (Optional[RunnableConfig]): The config to patch. - copy_locals (bool, optional): Whether to copy locals. Defaults to False. - callbacks (Optional[BaseCallbackManager], optional): The callbacks to set. - Defaults to None. - recursion_limit (Optional[int], optional): The recursion limit to set. - Defaults to None. - max_concurrency (Optional[int], optional): The max concurrency to set. - Defaults to None. - run_name (Optional[str], optional): The run name to set. Defaults to None. - configurable (Optional[Dict[str, Any]], optional): The configurable to set. - Defaults to None. - - Returns: - RunnableConfig: The patched config. 
- """ - config = ensure_config(config) - if callbacks is not None: - # If we're replacing callbacks, we need to unset run_name - # As that should apply only to the same run as the original callbacks - config["callbacks"] = callbacks - if "run_name" in config: - del config["run_name"] - if recursion_limit is not None: - config["recursion_limit"] = recursion_limit - if max_concurrency is not None: - config["max_concurrency"] = max_concurrency - if run_name is not None: - config["run_name"] = run_name - if configurable is not None: - config["configurable"] = {**config.get("configurable", {}), **configurable} - return config - - -def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: - """Merge multiple configs into one. - - Args: - *configs (Optional[RunnableConfig]): The configs to merge. - - Returns: - RunnableConfig: The merged config. - """ - base: RunnableConfig = {} - # Even though the keys aren't literals, this is correct - # because both dicts are the same type - for config in (c for c in configs if c is not None): - for key in config: - if key == "metadata": - base[key] = { # type: ignore - **base.get(key, {}), # type: ignore - **(config.get(key) or {}), # type: ignore - } - elif key == "tags": - base[key] = list( # type: ignore - set(base.get(key, []) + (config.get(key) or [])), # type: ignore - ) - elif key == "configurable": - base[key] = { # type: ignore - **base.get(key, {}), # type: ignore - **(config.get(key) or {}), # type: ignore - } - elif key == "callbacks": - base_callbacks = base.get("callbacks") - these_callbacks = config["callbacks"] - # callbacks can be either None, list[handler] or manager - # so merging two callbacks values has 6 cases - if isinstance(these_callbacks, list): - if base_callbacks is None: - base["callbacks"] = these_callbacks - elif isinstance(base_callbacks, list): - base["callbacks"] = base_callbacks + these_callbacks - else: - # base_callbacks is a manager - mngr = base_callbacks.copy() - for callback in these_callbacks: - mngr.add_handler(callback, inherit=True) - base["callbacks"] = mngr - elif these_callbacks is not None: - # these_callbacks is a manager - if base_callbacks is None: - base["callbacks"] = these_callbacks - elif isinstance(base_callbacks, list): - mngr = these_callbacks.copy() - for callback in base_callbacks: - mngr.add_handler(callback, inherit=True) - base["callbacks"] = mngr - else: - # base_callbacks is also a manager - base["callbacks"] = base_callbacks.__class__( - parent_run_id=base_callbacks.parent_run_id - or these_callbacks.parent_run_id, - handlers=base_callbacks.handlers + these_callbacks.handlers, - inheritable_handlers=base_callbacks.inheritable_handlers - + these_callbacks.inheritable_handlers, - tags=list(set(base_callbacks.tags + these_callbacks.tags)), - inheritable_tags=list( - set( - base_callbacks.inheritable_tags - + these_callbacks.inheritable_tags - ) - ), - metadata={ - **base_callbacks.metadata, - **these_callbacks.metadata, - }, - ) - else: - base[key] = config[key] or base.get(key) # type: ignore - return base - - -def call_func_with_variable_args( - func: Union[ - Callable[[Input], Output], - Callable[[Input, RunnableConfig], Output], - Callable[[Input, CallbackManagerForChainRun], Output], - Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], - ], - input: Input, - config: RunnableConfig, - run_manager: Optional[CallbackManagerForChainRun] = None, - **kwargs: Any, -) -> Output: - """Call function that may optionally accept a run_manager and/or config. 
- - Args: - func (Union[Callable[[Input], Output], - Callable[[Input, CallbackManagerForChainRun], Output], - Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output]]): - The function to call. - input (Input): The input to the function. - run_manager (CallbackManagerForChainRun): The run manager to - pass to the function. - config (RunnableConfig): The config to pass to the function. - **kwargs (Any): The keyword arguments to pass to the function. - - Returns: - Output: The output of the function. - """ - if accepts_config(func): - if run_manager is not None: - kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) - else: - kwargs["config"] = config - if run_manager is not None and accepts_run_manager(func): - kwargs["run_manager"] = run_manager - return func(input, **kwargs) # type: ignore[call-arg] - - -async def acall_func_with_variable_args( - func: Union[ - Callable[[Input], Awaitable[Output]], - Callable[[Input, RunnableConfig], Awaitable[Output]], - Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], - Callable[ - [Input, AsyncCallbackManagerForChainRun, RunnableConfig], - Awaitable[Output], - ], - ], - input: Input, - config: RunnableConfig, - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - **kwargs: Any, -) -> Output: - """Call function that may optionally accept a run_manager and/or config. - - Args: - func (Union[Callable[[Input], Awaitable[Output]], Callable[[Input, - AsyncCallbackManagerForChainRun], Awaitable[Output]], Callable[[Input, - AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]]]): - The function to call. - input (Input): The input to the function. - run_manager (AsyncCallbackManagerForChainRun): The run manager - to pass to the function. - config (RunnableConfig): The config to pass to the function. - **kwargs (Any): The keyword arguments to pass to the function. - - Returns: - Output: The output of the function. - """ - if accepts_config(func): - if run_manager is not None: - kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) - else: - kwargs["config"] = config - if run_manager is not None and accepts_run_manager(func): - kwargs["run_manager"] = run_manager - return await func(input, **kwargs) # type: ignore[call-arg] - - -def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: - """Get a callback manager for a config. - - Args: - config (RunnableConfig): The config. - - Returns: - CallbackManager: The callback manager. - """ - from langchain.callbacks.manager import CallbackManager - - return CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) - - -def get_async_callback_manager_for_config( - config: RunnableConfig, -) -> AsyncCallbackManager: - """Get an async callback manager for a config. - - Args: - config (RunnableConfig): The config. - - Returns: - AsyncCallbackManager: The async callback manager. - """ - from langchain.callbacks.manager import AsyncCallbackManager - - return AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) - - -@contextmanager -def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]: - """Get an executor for a config. - - Args: - config (RunnableConfig): The config. - - Yields: - Generator[Executor, None, None]: The executor. 
- """ - with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor: - yield executor +__all__ = [ + "EmptyDict", + "RunnableConfig", + "ensure_config", + "get_config_list", + "patch_config", + "merge_configs", + "call_func_with_variable_args", + "get_callback_manager_for_config", + "get_async_callback_manager_for_config", + "get_executor_for_config", +] diff --git a/libs/langchain/langchain/schema/runnable/configurable.py b/libs/langchain/langchain/schema/runnable/configurable.py index ffce3412ee4..a1463d57466 100644 --- a/libs/langchain/langchain/schema/runnable/configurable.py +++ b/libs/langchain/langchain/schema/runnable/configurable.py @@ -1,388 +1,15 @@ -from __future__ import annotations - -import enum -import threading -from abc import abstractmethod -from typing import ( - Any, - AsyncIterator, - Callable, - Dict, - Iterator, - List, - Optional, - Sequence, - Type, - Union, - cast, -) -from weakref import WeakValueDictionary - -from langchain.pydantic_v1 import BaseModel -from langchain.schema.runnable.base import Runnable, RunnableSerializable -from langchain.schema.runnable.config import ( - RunnableConfig, - get_config_list, - get_executor_for_config, -) -from langchain.schema.runnable.utils import ( - AnyConfigurableField, - ConfigurableField, - ConfigurableFieldMultiOption, - ConfigurableFieldSingleOption, - ConfigurableFieldSpec, - Input, - Output, - gather_with_concurrency, - get_unique_config_specs, +from langchain_core.runnables.configurable import ( + DynamicRunnable, + RunnableConfigurableAlternatives, + RunnableConfigurableFields, + StrEnum, + make_options_spec, ) - -class DynamicRunnable(RunnableSerializable[Input, Output]): - """A Serializable Runnable that can be dynamically configured.""" - - default: RunnableSerializable[Input, Output] - - class Config: - arbitrary_types_allowed = True - - @classmethod - def is_lc_serializable(cls) -> bool: - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - return cls.__module__.split(".")[:-1] - - @property - def InputType(self) -> Type[Input]: - return self.default.InputType - - @property - def OutputType(self) -> Type[Output]: - return self.default.OutputType - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self._prepare(config).get_input_schema(config) - - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self._prepare(config).get_output_schema(config) - - @abstractmethod - def _prepare( - self, config: Optional[RunnableConfig] = None - ) -> Runnable[Input, Output]: - ... 
- - def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - return self._prepare(config).invoke(input, config, **kwargs) - - async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - return await self._prepare(config).ainvoke(input, config, **kwargs) - - def batch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - configs = get_config_list(config, len(inputs)) - prepared = [self._prepare(c) for c in configs] - - if all(p is self.default for p in prepared): - return self.default.batch( - inputs, config, return_exceptions=return_exceptions, **kwargs - ) - - if not inputs: - return [] - - configs = get_config_list(config, len(inputs)) - - def invoke( - bound: Runnable[Input, Output], - input: Input, - config: RunnableConfig, - ) -> Union[Output, Exception]: - if return_exceptions: - try: - return bound.invoke(input, config, **kwargs) - except Exception as e: - return e - else: - return bound.invoke(input, config, **kwargs) - - # If there's only one input, don't bother with the executor - if len(inputs) == 1: - return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])]) - - with get_executor_for_config(configs[0]) as executor: - return cast( - List[Output], list(executor.map(invoke, prepared, inputs, configs)) - ) - - async def abatch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - configs = get_config_list(config, len(inputs)) - prepared = [self._prepare(c) for c in configs] - - if all(p is self.default for p in prepared): - return await self.default.abatch( - inputs, config, return_exceptions=return_exceptions, **kwargs - ) - - if not inputs: - return [] - - configs = get_config_list(config, len(inputs)) - - async def ainvoke( - bound: Runnable[Input, Output], - input: Input, - config: RunnableConfig, - ) -> Union[Output, Exception]: - if return_exceptions: - try: - return await bound.ainvoke(input, config, **kwargs) - except Exception as e: - return e - else: - return await bound.ainvoke(input, config, **kwargs) - - coros = map(ainvoke, prepared, inputs, configs) - return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) - - def stream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Iterator[Output]: - return self._prepare(config).stream(input, config, **kwargs) - - async def astream( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Output]: - async for chunk in self._prepare(config).astream(input, config, **kwargs): - yield chunk - - def transform( - self, - input: Iterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Iterator[Output]: - return self._prepare(config).transform(input, config, **kwargs) - - async def atransform( - self, - input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Output]: - async for chunk in self._prepare(config).atransform(input, config, **kwargs): - yield chunk - - -class RunnableConfigurableFields(DynamicRunnable[Input, Output]): - """A Runnable that can be dynamically configured.""" - - fields: Dict[str, 
AnyConfigurableField] - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - return get_unique_config_specs( - [ - ConfigurableFieldSpec( - id=spec.id, - name=spec.name, - description=spec.description - or self.default.__fields__[field_name].field_info.description, - annotation=spec.annotation - or self.default.__fields__[field_name].annotation, - default=getattr(self.default, field_name), - ) - if isinstance(spec, ConfigurableField) - else make_options_spec( - spec, self.default.__fields__[field_name].field_info.description - ) - for field_name, spec in self.fields.items() - ] - + list(self.default.config_specs) - ) - - def configurable_fields( - self, **kwargs: AnyConfigurableField - ) -> RunnableSerializable[Input, Output]: - return self.default.configurable_fields(**{**self.fields, **kwargs}) - - def _prepare( - self, config: Optional[RunnableConfig] = None - ) -> Runnable[Input, Output]: - config = config or {} - specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()} - configurable_fields = { - specs_by_id[k][0]: v - for k, v in config.get("configurable", {}).items() - if k in specs_by_id and isinstance(specs_by_id[k][1], ConfigurableField) - } - configurable_single_options = { - k: v.options[(config.get("configurable", {}).get(v.id) or v.default)] - for k, v in self.fields.items() - if isinstance(v, ConfigurableFieldSingleOption) - } - configurable_multi_options = { - k: [ - v.options[o] - for o in config.get("configurable", {}).get(v.id, v.default) - ] - for k, v in self.fields.items() - if isinstance(v, ConfigurableFieldMultiOption) - } - configurable = { - **configurable_fields, - **configurable_single_options, - **configurable_multi_options, - } - - if configurable: - return self.default.__class__(**{**self.default.__dict__, **configurable}) - else: - return self.default - - -# Before Python 3.11 native StrEnum is not available -class StrEnum(str, enum.Enum): - """A string enum.""" - - pass - - -_enums_for_spec: WeakValueDictionary[ - Union[ - ConfigurableFieldSingleOption, ConfigurableFieldMultiOption, ConfigurableField - ], - Type[StrEnum], -] = WeakValueDictionary() - -_enums_for_spec_lock = threading.Lock() - - -class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): - """A Runnable that can be dynamically configured.""" - - which: ConfigurableField - - alternatives: Dict[ - str, - Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]], - ] - - default_key: str = "default" - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - with _enums_for_spec_lock: - if which_enum := _enums_for_spec.get(self.which): - pass - else: - which_enum = StrEnum( # type: ignore[call-overload] - self.which.name or self.which.id, - ( - (v, v) - for v in list(self.alternatives.keys()) + [self.default_key] - ), - ) - _enums_for_spec[self.which] = cast(Type[StrEnum], which_enum) - return [ - ConfigurableFieldSpec( - id=self.which.id, - name=self.which.name, - description=self.which.description, - annotation=which_enum, - default=self.default_key, - ), - *self.default.config_specs, - ] + [ - s - for alt in self.alternatives.values() - if isinstance(alt, RunnableSerializable) - for s in alt.config_specs - ] - - def configurable_fields( - self, **kwargs: AnyConfigurableField - ) -> RunnableSerializable[Input, Output]: - return self.__class__( - which=self.which, - default=self.default.configurable_fields(**kwargs), - alternatives=self.alternatives, - ) - - def _prepare( - self, config: Optional[RunnableConfig] = None - 
) -> Runnable[Input, Output]: - config = config or {} - which = config.get("configurable", {}).get(self.which.id, self.default_key) - if which == self.default_key: - return self.default - elif which in self.alternatives: - alt = self.alternatives[which] - if isinstance(alt, Runnable): - return alt - else: - return alt() - else: - raise ValueError(f"Unknown alternative: {which}") - - -def make_options_spec( - spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption], - description: Optional[str], -) -> ConfigurableFieldSpec: - """Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or - ConfigurableFieldMultiOption.""" - with _enums_for_spec_lock: - if enum := _enums_for_spec.get(spec): - pass - else: - enum = StrEnum( # type: ignore[call-overload] - spec.name or spec.id, - ((v, v) for v in list(spec.options.keys())), - ) - _enums_for_spec[spec] = cast(Type[StrEnum], enum) - if isinstance(spec, ConfigurableFieldSingleOption): - return ConfigurableFieldSpec( - id=spec.id, - name=spec.name, - description=spec.description or description, - annotation=enum, - default=spec.default, - ) - else: - return ConfigurableFieldSpec( - id=spec.id, - name=spec.name, - description=spec.description or description, - annotation=Sequence[enum], # type: ignore[valid-type] - default=spec.default, - ) +__all__ = [ + "DynamicRunnable", + "RunnableConfigurableFields", + "StrEnum", + "RunnableConfigurableAlternatives", + "make_options_spec", +] diff --git a/libs/langchain/langchain/schema/runnable/fallbacks.py b/libs/langchain/langchain/schema/runnable/fallbacks.py index 4f0c6084ef2..7a54468d774 100644 --- a/libs/langchain/langchain/schema/runnable/fallbacks.py +++ b/libs/langchain/langchain/schema/runnable/fallbacks.py @@ -1,344 +1,3 @@ -import asyncio -from typing import ( - TYPE_CHECKING, - Any, - Iterator, - List, - Optional, - Sequence, - Tuple, - Type, - Union, -) +from langchain_core.runnables.fallbacks import RunnableWithFallbacks -from langchain.load.dump import dumpd -from langchain.pydantic_v1 import BaseModel -from langchain.schema.runnable.base import Runnable, RunnableSerializable -from langchain.schema.runnable.config import ( - RunnableConfig, - ensure_config, - get_async_callback_manager_for_config, - get_callback_manager_for_config, - get_config_list, - patch_config, -) -from langchain.schema.runnable.utils import ( - ConfigurableFieldSpec, - Input, - Output, - get_unique_config_specs, -) - -if TYPE_CHECKING: - from langchain.callbacks.manager import AsyncCallbackManagerForChainRun - - -class RunnableWithFallbacks(RunnableSerializable[Input, Output]): - """A Runnable that can fallback to other Runnables if it fails. - - External APIs (e.g., APIs for a language model) may at times experience - degraded performance or even downtime. - - In these cases, it can be useful to have a fallback runnable that can be - used in place of the original runnable (e.g., fallback to another LLM provider). - - Fallbacks can be defined at the level of a single runnable, or at the level - of a chain of runnables. Fallbacks are tried in order until one succeeds or - all fail. - - While you can instantiate a ``RunnableWithFallbacks`` directly, it is usually - more convenient to use the ``with_fallbacks`` method on a runnable. - - Example: - - .. 
code-block:: python - - from langchain.chat_models.openai import ChatOpenAI - from langchain.chat_models.anthropic import ChatAnthropic - - model = ChatAnthropic().with_fallbacks([ChatOpenAI()]) - # Will usually use ChatAnthropic, but fallback to ChatOpenAI - # if ChatAnthropic fails. - model.invoke('hello') - - # And you can also use fallbacks at the level of a chain. - # Here if both LLM providers fail, we'll fallback to a good hardcoded - # response. - - from langchain.prompts import PromptTemplate - from langchain.schema.output_parser import StrOutputParser - from langchain.schema.runnable import RunnableLambda - - def when_all_is_lost(inputs): - return ("Looks like our LLM providers are down. " - "Here's a nice 🦜️ emoji for you instead.") - - chain_with_fallback = ( - PromptTemplate.from_template('Tell me a joke about {topic}') - | model - | StrOutputParser() - ).with_fallbacks([RunnableLambda(when_all_is_lost)]) - """ - - runnable: Runnable[Input, Output] - """The runnable to run first.""" - fallbacks: Sequence[Runnable[Input, Output]] - """A sequence of fallbacks to try.""" - exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,) - """The exceptions on which fallbacks should be tried. - - Any exception that is not a subclass of these exceptions will be raised immediately. - """ - - class Config: - arbitrary_types_allowed = True - - @property - def InputType(self) -> Type[Input]: - return self.runnable.InputType - - @property - def OutputType(self) -> Type[Output]: - return self.runnable.OutputType - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self.runnable.get_input_schema(config) - - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self.runnable.get_output_schema(config) - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - return get_unique_config_specs( - spec - for step in [self.runnable, *self.fallbacks] - for spec in step.config_specs - ) - - @classmethod - def is_lc_serializable(cls) -> bool: - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - return cls.__module__.split(".")[:-1] - - @property - def runnables(self) -> Iterator[Runnable[Input, Output]]: - yield self.runnable - yield from self.fallbacks - - def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - # setup callbacks - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - # start the root run - run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") - ) - first_error = None - for runnable in self.runnables: - try: - output = runnable.invoke( - input, - patch_config(config, callbacks=run_manager.get_child()), - **kwargs, - ) - except self.exceptions_to_handle as e: - if first_error is None: - first_error = e - except BaseException as e: - run_manager.on_chain_error(e) - raise e - else: - run_manager.on_chain_end(output) - return output - if first_error is None: - raise ValueError("No error stored at end of fallbacks.") - run_manager.on_chain_error(first_error) - raise first_error - - async def ainvoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Output: - # setup callbacks - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) - # start the root run - run_manager = await callback_manager.on_chain_start( - 
dumpd(self), input, name=config.get("run_name") - ) - - first_error = None - for runnable in self.runnables: - try: - output = await runnable.ainvoke( - input, - patch_config(config, callbacks=run_manager.get_child()), - **kwargs, - ) - except self.exceptions_to_handle as e: - if first_error is None: - first_error = e - except BaseException as e: - await run_manager.on_chain_error(e) - raise e - else: - await run_manager.on_chain_end(output) - return output - if first_error is None: - raise ValueError("No error stored at end of fallbacks.") - await run_manager.on_chain_error(first_error) - raise first_error - - def batch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - from langchain.callbacks.manager import CallbackManager - - if return_exceptions: - raise NotImplementedError() - - if not inputs: - return [] - - # setup callbacks - configs = get_config_list(config, len(inputs)) - callback_managers = [ - CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) - for config in configs - ] - # start the root runs, one per input - run_managers = [ - cm.on_chain_start( - dumpd(self), - input if isinstance(input, dict) else {"input": input}, - name=config.get("run_name"), - ) - for cm, input, config in zip(callback_managers, inputs, configs) - ] - - first_error = None - for runnable in self.runnables: - try: - outputs = runnable.batch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config(config, callbacks=rm.get_child()) - for rm, config in zip(run_managers, configs) - ], - return_exceptions=return_exceptions, - **kwargs, - ) - except self.exceptions_to_handle as e: - if first_error is None: - first_error = e - except BaseException as e: - for rm in run_managers: - rm.on_chain_error(e) - raise e - else: - for rm, output in zip(run_managers, outputs): - rm.on_chain_end(output) - return outputs - if first_error is None: - raise ValueError("No error stored at end of fallbacks.") - for rm in run_managers: - rm.on_chain_error(first_error) - raise first_error - - async def abatch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - from langchain.callbacks.manager import AsyncCallbackManager - - if return_exceptions: - raise NotImplementedError() - - if not inputs: - return [] - - # setup callbacks - configs = get_config_list(config, len(inputs)) - callback_managers = [ - AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) - for config in configs - ] - # start the root runs, one per input - run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( - *( - cm.on_chain_start( - dumpd(self), - input, - name=config.get("run_name"), - ) - for cm, input, config in zip(callback_managers, inputs, configs) - ) - ) - - first_error = None - for runnable in self.runnables: - try: - outputs = await runnable.abatch( - inputs, - [ - # each step a child run of the corresponding root run - 
patch_config(config, callbacks=rm.get_child()) - for rm, config in zip(run_managers, configs) - ], - return_exceptions=return_exceptions, - **kwargs, - ) - except self.exceptions_to_handle as e: - if first_error is None: - first_error = e - except BaseException as e: - await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) - else: - await asyncio.gather( - *( - rm.on_chain_end(output) - for rm, output in zip(run_managers, outputs) - ) - ) - return outputs - if first_error is None: - raise ValueError("No error stored at end of fallbacks.") - await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers)) - raise first_error +__all__ = ["RunnableWithFallbacks"] diff --git a/libs/langchain/langchain/schema/runnable/history.py b/libs/langchain/langchain/schema/runnable/history.py index 701548b2a99..a7cba8f299c 100644 --- a/libs/langchain/langchain/schema/runnable/history.py +++ b/libs/langchain/langchain/schema/runnable/history.py @@ -1,288 +1,3 @@ -from __future__ import annotations +from langchain_core.runnables.history import RunnableWithMessageHistory -import asyncio -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Optional, - Sequence, - Type, - Union, -) - -from langchain.load import load -from langchain.pydantic_v1 import BaseModel, create_model -from langchain.schema.chat_history import BaseChatMessageHistory -from langchain.schema.runnable.base import Runnable, RunnableBindingBase, RunnableLambda -from langchain.schema.runnable.passthrough import RunnablePassthrough -from langchain.schema.runnable.utils import ( - ConfigurableFieldSpec, - get_unique_config_specs, -) - -if TYPE_CHECKING: - from langchain.callbacks.tracers.schemas import Run - from langchain.schema.messages import BaseMessage - from langchain.schema.runnable.config import RunnableConfig - -MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]] -GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] - - -class RunnableWithMessageHistory(RunnableBindingBase): - """A runnable that manages chat message history for another runnable. - - Base runnable must have inputs and outputs that can be converted to a list of - BaseMessages. - - RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.: - ``{"configurable": {"session_id": ""}}`` - - Example (dict input): - .. code-block:: python - - from typing import Optional - - from langchain.chat_models import ChatAnthropic - from langchain.memory.chat_message_histories import RedisChatMessageHistory - from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder - from langchain.schema.runnable.history import RunnableWithMessageHistory - - - prompt = ChatPromptTemplate.from_messages([ - ("system", "You're an assistant who's good at {ability}"), - MessagesPlaceholder(variable_name="history"), - ("human", "{question}"), - ]) - - chain = prompt | ChatAnthropic(model="claude-2") - - chain_with_history = RunnableWithMessageHistory( - chain, - RedisChatMessageHistory, - input_messages_key="question", - history_messages_key="history", - ) - - chain_with_history.invoke( - {"ability": "math", "question": "What does cosine mean?"}, - config={"configurable": {"session_id": "foo"}} - ) - # -> "Cosine is ..." - chain_with_history.invoke( - {"ability": "math", "question": "What's its inverse"}, - config={"configurable": {"session_id": "foo"}} - ) - # -> "The inverse of cosine is called arccosine ..." 
- - """ # noqa: E501 - - get_session_history: GetSessionHistoryCallable - input_messages_key: Optional[str] = None - output_messages_key: Optional[str] = None - history_messages_key: Optional[str] = None - - def __init__( - self, - runnable: Runnable[ - MessagesOrDictWithMessages, - Union[str, BaseMessage, MessagesOrDictWithMessages], - ], - get_session_history: GetSessionHistoryCallable, - *, - input_messages_key: Optional[str] = None, - output_messages_key: Optional[str] = None, - history_messages_key: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Initialize RunnableWithMessageHistory. - - Args: - runnable: The base Runnable to be wrapped. - - Must take as input one of: - - A sequence of BaseMessages - - A dict with one key for all messages - - A dict with one key for the current input string/message(s) and - a separate key for historical messages. If the input key points - to a string, it will be treated as a HumanMessage in history. - - Must return as output one of: - - A string which can be treated as an AIMessage - - A BaseMessage or sequence of BaseMessages - - A dict with a key for a BaseMessage or sequence of BaseMessages - - get_session_history: Function that returns a new BaseChatMessageHistory - given a session id. Should take a single - positional argument `session_id` which is a string and a named argument - `user_id` which can be a string or None. e.g.: - - ```python - def get_session_history( - session_id: str, - *, - user_id: Optional[str]=None - ) -> BaseChatMessageHistory: - ... - ``` - - input_messages_key: Must be specified if the base runnable accepts a dict - as input. - output_messages_key: Must be specified if the base runnable returns a dict - as output. - history_messages_key: Must be specified if the base runnable accepts a dict - as input and expects a separate key for historical messages. - **kwargs: Arbitrary additional kwargs to pass to parent class - ``RunnableBindingBase`` init. - """ # noqa: E501 - history_chain: Runnable = RunnableLambda( - self._enter_history, self._aenter_history - ).with_config(run_name="load_history") - messages_key = history_messages_key or input_messages_key - if messages_key: - history_chain = RunnablePassthrough.assign( - **{messages_key: history_chain} - ).with_config(run_name="insert_history") - bound = ( - history_chain | runnable.with_listeners(on_end=self._exit_history) - ).with_config(run_name="RunnableWithMessageHistory") - super().__init__( - get_session_history=get_session_history, - input_messages_key=input_messages_key, - output_messages_key=output_messages_key, - bound=bound, - history_messages_key=history_messages_key, - **kwargs, - ) - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - return get_unique_config_specs( - super().config_specs - + [ - ConfigurableFieldSpec( - id="session_id", - annotation=str, - name="Session ID", - description="Unique identifier for a session.", - default="", - ), - ] - ) - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - super_schema = super().get_input_schema(config) - if super_schema.__custom_root_type__ is not None: - from langchain.schema.messages import BaseMessage - - fields: Dict = {} - if self.input_messages_key and self.history_messages_key: - fields[self.input_messages_key] = ( - Union[str, BaseMessage, Sequence[BaseMessage]], - ..., - ) - elif self.input_messages_key: - fields[self.input_messages_key] = (Sequence[BaseMessage], ...) - else: - fields["__root__"] = (Sequence[BaseMessage], ...) 
- if self.history_messages_key: - fields[self.history_messages_key] = (Sequence[BaseMessage], ...) - return create_model( # type: ignore[call-overload] - "RunnableWithChatHistoryInput", - **fields, - ) - else: - return super_schema - - def _get_input_messages( - self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]] - ) -> List[BaseMessage]: - from langchain.schema.messages import BaseMessage - - if isinstance(input_val, str): - from langchain.schema.messages import HumanMessage - - return [HumanMessage(content=input_val)] - elif isinstance(input_val, BaseMessage): - return [input_val] - elif isinstance(input_val, (list, tuple)): - return list(input_val) - else: - raise ValueError( - f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. " - f"Got {input_val}." - ) - - def _get_output_messages( - self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] - ) -> List[BaseMessage]: - from langchain.schema.messages import BaseMessage - - if isinstance(output_val, dict): - output_val = output_val[self.output_messages_key or "output"] - - if isinstance(output_val, str): - from langchain.schema.messages import AIMessage - - return [AIMessage(content=output_val)] - elif isinstance(output_val, BaseMessage): - return [output_val] - elif isinstance(output_val, (list, tuple)): - return list(output_val) - else: - raise ValueError() - - def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]: - hist = config["configurable"]["message_history"] - # return only historic messages - if self.history_messages_key: - return hist.messages.copy() - # return all messages - else: - input_val = ( - input if not self.input_messages_key else input[self.input_messages_key] - ) - return hist.messages.copy() + self._get_input_messages(input_val) - - async def _aenter_history( - self, input: Dict[str, Any], config: RunnableConfig - ) -> List[BaseMessage]: - return await asyncio.get_running_loop().run_in_executor( - None, self._enter_history, input, config - ) - - def _exit_history(self, run: Run, config: RunnableConfig) -> None: - hist = config["configurable"]["message_history"] - - # Get the input messages - inputs = load(run.inputs) - input_val = inputs[self.input_messages_key or "input"] - input_messages = self._get_input_messages(input_val) - - # Get the output messages - output_val = load(run.outputs) - output_messages = self._get_output_messages(output_val) - - for m in input_messages + output_messages: - hist.add_message(m) - - def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: - config = super()._merge_configs(*configs) - # extract session_id - if "session_id" not in config.get("configurable", {}): - example_input = {self.input_messages_key: "foo"} - example_config = {"configurable": {"session_id": "123"}} - raise ValueError( - "session_id_id is required." - " Pass it in as part of the config argument to .invoke() or .stream()" - f"\neg. 
chain.invoke({example_input}, {example_config})" - ) - # attach message_history - session_id = config["configurable"]["session_id"] - config["configurable"]["message_history"] = self.get_session_history(session_id) - return config +__all__ = ["RunnableWithMessageHistory"] diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index fef81e96706..81141a7d363 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -1,453 +1,7 @@ -"""Implementation of the RunnablePassthrough.""" -from __future__ import annotations - -import asyncio -import inspect -import threading -from typing import ( - Any, - AsyncIterator, - Awaitable, - Callable, - Dict, - Iterator, - List, - Mapping, - Optional, - Type, - Union, - cast, +from langchain_core.runnables.passthrough import ( + RunnableAssign, + RunnablePassthrough, + identity, ) -from langchain.pydantic_v1 import BaseModel, create_model -from langchain.schema.runnable.base import ( - Other, - Runnable, - RunnableParallel, - RunnableSerializable, -) -from langchain.schema.runnable.config import ( - RunnableConfig, - acall_func_with_variable_args, - call_func_with_variable_args, - get_executor_for_config, -) -from langchain.schema.runnable.utils import AddableDict, ConfigurableFieldSpec -from langchain.utils.aiter import atee, py_anext -from langchain.utils.iter import safetee - - -def identity(x: Other) -> Other: - """An identity function""" - return x - - -async def aidentity(x: Other) -> Other: - """An async identity function""" - return x - - -class RunnablePassthrough(RunnableSerializable[Other, Other]): - """A runnable to passthrough inputs unchanged or with additional keys. - - This runnable behaves almost like the identity function, except that it - can be configured to add additional keys to the output, if the input is a - dict. - - The examples below demonstrate this runnable works using a few simple - chains. The chains rely on simple lambdas to make the examples easy to execute - and experiment with. - - Examples: - - .. code-block:: python - - from langchain.schema.runnable import RunnablePassthrough, RunnableParallel - - runnable = RunnableParallel( - origin=RunnablePassthrough(), - modified=lambda x: x+1 - ) - - runnable.invoke(1) # {'origin': 1, 'modified': 2} - - - def fake_llm(prompt: str) -> str: # Fake LLM for the example - return "completion" - - chain = RunnableLambda(fake_llm) | { - 'original': RunnablePassthrough(), # Original LLM output - 'parsed': lambda text: text[::-1] # Parsing logic - } - - chain.invoke('hello') # {'original': 'completion', 'parsed': 'noitelpmoc'} - - In some cases, it may be useful to pass the input through while adding some - keys to the output. In this case, you can use the `assign` method: - - .. 
code-block:: python - - from langchain.schema.runnable import RunnablePassthrough, RunnableParallel - - def fake_llm(prompt: str) -> str: # Fake LLM for the example - return "completion" - - runnable = { - 'llm1': fake_llm, - 'llm2': fake_llm, - } - | RunnablePassthrough.assign( - total_chars=lambda inputs: len(inputs['llm1'] + inputs['llm2']) - ) - - runnable.invoke('hello') - # {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20} - """ - - input_type: Optional[Type[Other]] = None - - func: Optional[ - Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]] - ] = None - - afunc: Optional[ - Union[ - Callable[[Other], Awaitable[None]], - Callable[[Other, RunnableConfig], Awaitable[None]], - ] - ] = None - - def __init__( - self, - func: Optional[ - Union[ - Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]], - Union[ - Callable[[Other], Awaitable[None]], - Callable[[Other, RunnableConfig], Awaitable[None]], - ], - ] - ] = None, - afunc: Optional[ - Union[ - Callable[[Other], Awaitable[None]], - Callable[[Other, RunnableConfig], Awaitable[None]], - ] - ] = None, - *, - input_type: Optional[Type[Other]] = None, - **kwargs: Any, - ) -> None: - if inspect.iscoroutinefunction(func): - afunc = func - func = None - - super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs) - - @classmethod - def is_lc_serializable(cls) -> bool: - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - return cls.__module__.split(".")[:-1] - - @property - def InputType(self) -> Any: - return self.input_type or Any - - @property - def OutputType(self) -> Any: - return self.input_type or Any - - @classmethod - def assign( - cls, - **kwargs: Union[ - Runnable[Dict[str, Any], Any], - Callable[[Dict[str, Any]], Any], - Mapping[ - str, - Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]], - ], - ], - ) -> RunnableAssign: - """Merge the Dict input with the output produced by the mapping argument. - - Args: - mapping: A mapping from keys to runnables or callables. - - Returns: - A runnable that merges the Dict input with the output produced by the - mapping argument. 
- """ - return RunnableAssign(RunnableParallel(kwargs)) - - def invoke( - self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Other: - if self.func is not None: - call_func_with_variable_args(self.func, input, config or {}, **kwargs) - return self._call_with_config(identity, input, config) - - async def ainvoke( - self, - input: Other, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Other: - if self.afunc is not None: - await acall_func_with_variable_args( - self.afunc, input, config or {}, **kwargs - ) - elif self.func is not None: - call_func_with_variable_args(self.func, input, config or {}, **kwargs) - return await self._acall_with_config(aidentity, input, config) - - def transform( - self, - input: Iterator[Other], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Iterator[Other]: - if self.func is None: - for chunk in self._transform_stream_with_config(input, identity, config): - yield chunk - else: - final = None - - for chunk in self._transform_stream_with_config(input, identity, config): - yield chunk - if final is None: - final = chunk - else: - final = final + chunk - - if final is not None: - call_func_with_variable_args(self.func, final, config or {}, **kwargs) - - async def atransform( - self, - input: AsyncIterator[Other], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> AsyncIterator[Other]: - if self.afunc is None and self.func is None: - async for chunk in self._atransform_stream_with_config( - input, identity, config - ): - yield chunk - else: - final = None - - async for chunk in self._atransform_stream_with_config( - input, identity, config - ): - yield chunk - if final is None: - final = chunk - else: - final = final + chunk - - if final is not None: - config = config or {} - if self.afunc is not None: - await acall_func_with_variable_args( - self.afunc, final, config, **kwargs - ) - elif self.func is not None: - call_func_with_variable_args(self.func, final, config, **kwargs) - - def stream( - self, - input: Other, - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Iterator[Other]: - return self.transform(iter([input]), config, **kwargs) - - async def astream( - self, - input: Other, - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> AsyncIterator[Other]: - async def input_aiter() -> AsyncIterator[Other]: - yield input - - async for chunk in self.atransform(input_aiter(), config, **kwargs): - yield chunk - - -class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): - """ - A runnable that assigns key-value pairs to Dict[str, Any] inputs. - """ - - mapper: RunnableParallel[Dict[str, Any]] - - def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None: - super().__init__(mapper=mapper, **kwargs) - - @classmethod - def is_lc_serializable(cls) -> bool: - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - return cls.__module__.split(".")[:-1] - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - map_input_schema = self.mapper.get_input_schema(config) - if not map_input_schema.__custom_root_type__: - # ie. 
it's a dict - return map_input_schema - - return super().get_input_schema(config) - - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - map_input_schema = self.mapper.get_input_schema(config) - map_output_schema = self.mapper.get_output_schema(config) - if ( - not map_input_schema.__custom_root_type__ - and not map_output_schema.__custom_root_type__ - ): - # ie. both are dicts - return create_model( # type: ignore[call-overload] - "RunnableAssignOutput", - **{ - k: (v.type_, v.default) - for s in (map_input_schema, map_output_schema) - for k, v in s.__fields__.items() - }, - ) - elif not map_output_schema.__custom_root_type__: - # ie. only map output is a dict - # ie. input type is either unknown or inferred incorrectly - return map_output_schema - - return super().get_output_schema(config) - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - return self.mapper.config_specs - - def invoke( - self, - input: Dict[str, Any], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Dict[str, Any]: - assert isinstance( - input, dict - ), "The input to RunnablePassthrough.assign() must be a dict." - return { - **input, - **self.mapper.invoke(input, config, **kwargs), - } - - async def ainvoke( - self, - input: Dict[str, Any], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Dict[str, Any]: - assert isinstance( - input, dict - ), "The input to RunnablePassthrough.assign() must be a dict." - return { - **input, - **await self.mapper.ainvoke(input, config, **kwargs), - } - - def transform( - self, - input: Iterator[Dict[str, Any]], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Iterator[Dict[str, Any]]: - # collect mapper keys - mapper_keys = set(self.mapper.steps.keys()) - # create two streams, one for the map and one for the passthrough - for_passthrough, for_map = safetee(input, 2, lock=threading.Lock()) - # create map output stream - map_output = self.mapper.transform(for_map, config, **kwargs) - # get executor to start map output stream in background - with get_executor_for_config(config or {}) as executor: - # start map output stream - first_map_chunk_future = executor.submit( - next, - map_output, # type: ignore - None, - ) - # consume passthrough stream - for chunk in for_passthrough: - assert isinstance( - chunk, dict - ), "The input to RunnablePassthrough.assign() must be a dict." - # remove mapper keys from passthrough chunk, to be overwritten by map - filtered = AddableDict( - {k: v for k, v in chunk.items() if k not in mapper_keys} - ) - if filtered: - yield filtered - # yield map output - yield cast(Dict[str, Any], first_map_chunk_future.result()) - for chunk in map_output: - yield chunk - - async def atransform( - self, - input: AsyncIterator[Dict[str, Any]], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> AsyncIterator[Dict[str, Any]]: - # collect mapper keys - mapper_keys = set(self.mapper.steps.keys()) - # create two streams, one for the map and one for the passthrough - for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock()) - # create map output stream - map_output = self.mapper.atransform(for_map, config, **kwargs) - # start map output stream - first_map_chunk_task: asyncio.Task = asyncio.create_task( - py_anext(map_output, None), # type: ignore[arg-type] - ) - # consume passthrough stream - async for chunk in for_passthrough: - assert isinstance( - chunk, dict - ), "The input to RunnablePassthrough.assign() must be a dict." 
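The RunnableAssign streaming logic above forwards passthrough keys as soon as they arrive and then appends the mapper's output. A minimal usage sketch of the public `RunnablePassthrough.assign()` entry point, assuming only the exports shown in this module (`fake_llm` is an illustrative stand-in for a model call):

.. code-block:: python

    from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

    def fake_llm(prompt: str) -> str:
        # Illustrative stand-in for a real model call.
        return "completion"

    # The dict on the left is coerced to a RunnableParallel; assign() then merges
    # its output dict with the extra key computed by the lambda.
    chain = {
        "llm1": RunnableLambda(fake_llm),
        "llm2": RunnableLambda(fake_llm),
    } | RunnablePassthrough.assign(
        total_chars=lambda inputs: len(inputs["llm1"] + inputs["llm2"])
    )

    print(chain.invoke("hello"))
    # {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}

When the same chain is consumed with `.stream()`, the passthrough keys are emitted as they arrive and the assigned keys once the mapper finishes, which is the bookkeeping implemented in `transform()`/`atransform()` above.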
- # remove mapper keys from passthrough chunk, to be overwritten by map output - filtered = AddableDict( - {k: v for k, v in chunk.items() if k not in mapper_keys} - ) - if filtered: - yield filtered - # yield map output - yield await first_map_chunk_task - async for chunk in map_output: - yield chunk - - def stream( - self, - input: Dict[str, Any], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Iterator[Dict[str, Any]]: - return self.transform(iter([input]), config, **kwargs) - - async def astream( - self, - input: Dict[str, Any], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> AsyncIterator[Dict[str, Any]]: - async def input_aiter() -> AsyncIterator[Dict[str, Any]]: - yield input - - async for chunk in self.atransform(input_aiter(), config, **kwargs): - yield chunk +__all__ = ["identity", "RunnablePassthrough", "RunnableAssign"] diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index 99b665bf1ee..4e1f4dbcc2e 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -1,337 +1,3 @@ -from typing import ( - TYPE_CHECKING, - Any, - Dict, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, - cast, -) +from langchain_core.runnables.retry import RunnableRetry -from tenacity import ( - AsyncRetrying, - RetryCallState, - RetryError, - Retrying, - retry_if_exception_type, - stop_after_attempt, - wait_exponential_jitter, -) - -from langchain.schema.runnable.base import Input, Output, RunnableBindingBase -from langchain.schema.runnable.config import RunnableConfig, patch_config - -if TYPE_CHECKING: - from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, - ) - - T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun) -U = TypeVar("U") - - -class RunnableRetry(RunnableBindingBase[Input, Output]): - """Retry a Runnable if it fails. - - A RunnableRetry helps can be used to add retry logic to any object - that subclasses the base Runnable. - - Such retries are especially useful for network calls that may fail - due to transient errors. - - The RunnableRetry is implemented as a RunnableBinding. The easiest - way to use it is through the `.with_retry()` method on all Runnables. - - Example: - - Here's an example that uses a RunnableLambda to raise an exception - - .. code-block:: python - - import time - - def foo(input) -> None: - '''Fake function that raises an exception.''' - raise ValueError("Invoking foo failed. At time {time.time()}") - - runnable = RunnableLambda(foo) - - runnable_with_retries = runnable.with_retry( - retry_exception_types=(ValueError,), # Retry only on ValueError - wait_exponential_jitter=True, # Add jitter to the exponential backoff - max_attempt_number=2, # Try twice - ) - - # The method invocation above is equivalent to the longer form below: - - runnable_with_retries = RunnableRetry( - bound=runnable, - retry_exception_types=(ValueError,), - max_attempt_number=2, - wait_exponential_jitter=True - ) - - This logic can be used to retry any Runnable, including a chain of Runnables, - but in general it's best practice to keep the scope of the retry as small as - possible. For example, if you have a chain of Runnables, you should only retry - the Runnable that is likely to fail, not the entire chain. - - Example: - - .. 
code-block:: python - - from langchain.chat_models import ChatOpenAI - from langchain.prompts import PromptTemplate - - template = PromptTemplate.from_template("tell me a joke about {topic}.") - model = ChatOpenAI(temperature=0.5) - - # Good - chain = template | model.with_retry() - - # Bad - chain = template | model - retryable_chain = chain.with_retry() - """ - - retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,) - """The exception types to retry on. By default all exceptions are retried. - - In general you should only retry on exceptions that are likely to be - transient, such as network errors. - - Good exceptions to retry are all server errors (5xx) and selected client - errors (4xx) such as 429 Too Many Requests. - """ - - wait_exponential_jitter: bool = True - """Whether to add jitter to the exponential backoff.""" - - max_attempt_number: int = 3 - """The maximum number of attempts to retry the runnable.""" - - @property - def _kwargs_retrying(self) -> Dict[str, Any]: - kwargs: Dict[str, Any] = dict() - - if self.max_attempt_number: - kwargs["stop"] = stop_after_attempt(self.max_attempt_number) - - if self.wait_exponential_jitter: - kwargs["wait"] = wait_exponential_jitter() - - if self.retry_exception_types: - kwargs["retry"] = retry_if_exception_type(self.retry_exception_types) - - return kwargs - - def _sync_retrying(self, **kwargs: Any) -> Retrying: - return Retrying(**self._kwargs_retrying, **kwargs) - - def _async_retrying(self, **kwargs: Any) -> AsyncRetrying: - return AsyncRetrying(**self._kwargs_retrying, **kwargs) - - def _patch_config( - self, - config: RunnableConfig, - run_manager: "T", - retry_state: RetryCallState, - ) -> RunnableConfig: - attempt = retry_state.attempt_number - tag = "retry:attempt:{}".format(attempt) if attempt > 1 else None - return patch_config(config, callbacks=run_manager.get_child(tag)) - - def _patch_config_list( - self, - config: List[RunnableConfig], - run_manager: List["T"], - retry_state: RetryCallState, - ) -> List[RunnableConfig]: - return [ - self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager) - ] - - def _invoke( - self, - input: Input, - run_manager: "CallbackManagerForChainRun", - config: RunnableConfig, - **kwargs: Any, - ) -> Output: - for attempt in self._sync_retrying(reraise=True): - with attempt: - result = super().invoke( - input, - self._patch_config(config, run_manager, attempt.retry_state), - **kwargs, - ) - if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: - attempt.retry_state.set_result(result) - return result - - def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - return self._call_with_config(self._invoke, input, config, **kwargs) - - async def _ainvoke( - self, - input: Input, - run_manager: "AsyncCallbackManagerForChainRun", - config: RunnableConfig, - **kwargs: Any, - ) -> Output: - async for attempt in self._async_retrying(reraise=True): - with attempt: - result = await super().ainvoke( - input, - self._patch_config(config, run_manager, attempt.retry_state), - **kwargs, - ) - if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: - attempt.retry_state.set_result(result) - return result - - async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - return await self._acall_with_config(self._ainvoke, input, config, **kwargs) - - def _batch( - self, - inputs: List[Input], - run_manager: List["CallbackManagerForChainRun"], - 
config: List[RunnableConfig], - **kwargs: Any, - ) -> List[Union[Output, Exception]]: - results_map: Dict[int, Output] = {} - - def pending(iterable: List[U]) -> List[U]: - return [item for idx, item in enumerate(iterable) if idx not in results_map] - - try: - for attempt in self._sync_retrying(): - with attempt: - # Get the results of the inputs that have not succeeded yet. - result = super().batch( - pending(inputs), - self._patch_config_list( - pending(config), pending(run_manager), attempt.retry_state - ), - return_exceptions=True, - **kwargs, - ) - # Register the results of the inputs that have succeeded. - first_exception = None - for i, r in enumerate(result): - if isinstance(r, Exception): - if not first_exception: - first_exception = r - continue - results_map[i] = r - # If any exception occurred, raise it, to retry the failed ones - if first_exception: - raise first_exception - if ( - attempt.retry_state.outcome - and not attempt.retry_state.outcome.failed - ): - attempt.retry_state.set_result(result) - except RetryError as e: - try: - result - except UnboundLocalError: - result = cast(List[Output], [e] * len(inputs)) - - outputs: List[Union[Output, Exception]] = [] - for idx, _ in enumerate(inputs): - if idx in results_map: - outputs.append(results_map[idx]) - else: - outputs.append(result.pop(0)) - return outputs - - def batch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Any, - ) -> List[Output]: - return self._batch_with_config( - self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs - ) - - async def _abatch( - self, - inputs: List[Input], - run_manager: List["AsyncCallbackManagerForChainRun"], - config: List[RunnableConfig], - **kwargs: Any, - ) -> List[Union[Output, Exception]]: - results_map: Dict[int, Output] = {} - - def pending(iterable: List[U]) -> List[U]: - return [item for idx, item in enumerate(iterable) if idx not in results_map] - - try: - async for attempt in self._async_retrying(): - with attempt: - # Get the results of the inputs that have not succeeded yet. - result = await super().abatch( - pending(inputs), - self._patch_config_list( - pending(config), pending(run_manager), attempt.retry_state - ), - return_exceptions=True, - **kwargs, - ) - # Register the results of the inputs that have succeeded. 
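The comments in `_batch()` above describe the bookkeeping that retries only the inputs which have not yet succeeded. A stripped-down sketch of the same pattern built directly on tenacity, with an invented `flaky()` operation standing in for the wrapped runnable:

.. code-block:: python

    from typing import Dict, List

    from tenacity import Retrying, retry_if_exception_type, stop_after_attempt

    attempts_seen: Dict[int, int] = {}

    def flaky(x: int) -> int:
        # Illustrative operation: every odd input fails once, then succeeds.
        attempts_seen[x] = attempts_seen.get(x, 0) + 1
        if x % 2 and attempts_seen[x] == 1:
            raise ValueError(f"transient failure for {x}")
        return x * 10

    inputs: List[int] = [1, 2, 3, 4]
    results: Dict[int, int] = {}  # input index -> successful result

    for attempt in Retrying(
        stop=stop_after_attempt(3),
        retry=retry_if_exception_type(ValueError),
        reraise=True,
    ):
        with attempt:
            first_exception = None
            # Only re-run the inputs that have not succeeded on an earlier attempt.
            for i, x in enumerate(inputs):
                if i in results:
                    continue
                try:
                    results[i] = flaky(x)
                except ValueError as exc:
                    first_exception = first_exception or exc
            if first_exception is not None:
                # Re-raise so tenacity schedules another attempt for the failures.
                raise first_exception

    print([results[i] for i in range(len(inputs))])  # [10, 20, 30, 40]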
- first_exception = None - for i, r in enumerate(result): - if isinstance(r, Exception): - if not first_exception: - first_exception = r - continue - results_map[i] = r - # If any exception occurred, raise it, to retry the failed ones - if first_exception: - raise first_exception - if ( - attempt.retry_state.outcome - and not attempt.retry_state.outcome.failed - ): - attempt.retry_state.set_result(result) - except RetryError as e: - try: - result - except UnboundLocalError: - result = cast(List[Output], [e] * len(inputs)) - - outputs: List[Union[Output, Exception]] = [] - for idx, _ in enumerate(inputs): - if idx in results_map: - outputs.append(results_map[idx]) - else: - outputs.append(result.pop(0)) - return outputs - - async def abatch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Any, - ) -> List[Output]: - return await self._abatch_with_config( - self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs - ) - - # stream() and transform() are not retried because retrying a stream - # is not very intuitive. +__all__ = ["RunnableRetry"] diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index e0fea406df5..259dd677c98 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -1,206 +1,3 @@ -from __future__ import annotations +from langchain_core.runnables.router import RouterInput, RouterRunnable -from typing import ( - Any, - AsyncIterator, - Callable, - Iterator, - List, - Mapping, - Optional, - Union, - cast, -) - -from typing_extensions import TypedDict - -from langchain.schema.runnable.base import ( - Input, - Output, - Runnable, - RunnableSerializable, - coerce_to_runnable, -) -from langchain.schema.runnable.config import ( - RunnableConfig, - get_config_list, - get_executor_for_config, -) -from langchain.schema.runnable.utils import ( - ConfigurableFieldSpec, - gather_with_concurrency, - get_unique_config_specs, -) - - -class RouterInput(TypedDict): - """A Router input. - - Attributes: - key: The key to route on. - input: The input to pass to the selected runnable. - """ - - key: str - input: Any - - -class RouterRunnable(RunnableSerializable[RouterInput, Output]): - """ - A runnable that routes to a set of runnables based on Input['key']. - Returns the output of the selected runnable. 
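RouterRunnable dispatches on the "key" field of its input and forwards the "input" field to the selected runnable. A minimal sketch, assuming the exports available from `langchain.schema.runnable` in this tree (the arithmetic lambdas are illustrative):

.. code-block:: python

    from langchain.schema.runnable import RouterRunnable, RunnableLambda

    router = RouterRunnable(
        {
            "add_ten": RunnableLambda(lambda x: x + 10),
            "square": RunnableLambda(lambda x: x * x),
        }
    )

    # The "key" selects the runnable; the "input" is what it receives.
    print(router.invoke({"key": "square", "input": 4}))  # 16
    print(
        router.batch(
            [
                {"key": "add_ten", "input": 1},
                {"key": "square", "input": 3},
            ]
        )
    )  # [11, 9]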
- """ - - runnables: Mapping[str, Runnable[Any, Output]] - - @property - def config_specs(self) -> List[ConfigurableFieldSpec]: - return get_unique_config_specs( - spec for step in self.runnables.values() for spec in step.config_specs - ) - - def __init__( - self, - runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]], - ) -> None: - super().__init__( - runnables={key: coerce_to_runnable(r) for key, r in runnables.items()} - ) - - class Config: - arbitrary_types_allowed = True - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this class is serializable.""" - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - return cls.__module__.split(".")[:-1] - - def invoke( - self, input: RouterInput, config: Optional[RunnableConfig] = None - ) -> Output: - key = input["key"] - actual_input = input["input"] - if key not in self.runnables: - raise ValueError(f"No runnable associated with key '{key}'") - - runnable = self.runnables[key] - return runnable.invoke(actual_input, config) - - async def ainvoke( - self, - input: RouterInput, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Output: - key = input["key"] - actual_input = input["input"] - if key not in self.runnables: - raise ValueError(f"No runnable associated with key '{key}'") - - runnable = self.runnables[key] - return await runnable.ainvoke(actual_input, config) - - def batch( - self, - inputs: List[RouterInput], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - if not inputs: - return [] - - keys = [input["key"] for input in inputs] - actual_inputs = [input["input"] for input in inputs] - if any(key not in self.runnables for key in keys): - raise ValueError("One or more keys do not have a corresponding runnable") - - def invoke( - runnable: Runnable, input: Input, config: RunnableConfig - ) -> Union[Output, Exception]: - if return_exceptions: - try: - return runnable.invoke(input, config, **kwargs) - except Exception as e: - return e - else: - return runnable.invoke(input, config, **kwargs) - - runnables = [self.runnables[key] for key in keys] - configs = get_config_list(config, len(inputs)) - with get_executor_for_config(configs[0]) as executor: - return cast( - List[Output], - list(executor.map(invoke, runnables, actual_inputs, configs)), - ) - - async def abatch( - self, - inputs: List[RouterInput], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - if not inputs: - return [] - - keys = [input["key"] for input in inputs] - actual_inputs = [input["input"] for input in inputs] - if any(key not in self.runnables for key in keys): - raise ValueError("One or more keys do not have a corresponding runnable") - - async def ainvoke( - runnable: Runnable, input: Input, config: RunnableConfig - ) -> Union[Output, Exception]: - if return_exceptions: - try: - return await runnable.ainvoke(input, config, **kwargs) - except Exception as e: - return e - else: - return await runnable.ainvoke(input, config, **kwargs) - - runnables = [self.runnables[key] for key in keys] - configs = get_config_list(config, len(inputs)) - return await gather_with_concurrency( - configs[0].get("max_concurrency"), - *( - ainvoke(runnable, input, config) - for runnable, input, config in zip(runnables, actual_inputs, configs) - ), - ) - - def stream( - self, - 
input: RouterInput, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Iterator[Output]: - key = input["key"] - actual_input = input["input"] - if key not in self.runnables: - raise ValueError(f"No runnable associated with key '{key}'") - - runnable = self.runnables[key] - yield from runnable.stream(actual_input, config) - - async def astream( - self, - input: RouterInput, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Output]: - key = input["key"] - actual_input = input["input"] - if key not in self.runnables: - raise ValueError(f"No runnable associated with key '{key}'") - - runnable = self.runnables[key] - async for output in runnable.astream(actual_input, config): - yield output +__all__ = ["RouterInput", "RouterRunnable"] diff --git a/libs/langchain/langchain/schema/runnable/utils.py b/libs/langchain/langchain/schema/runnable/utils.py index aafd9d59458..020d23f5469 100644 --- a/libs/langchain/langchain/schema/runnable/utils.py +++ b/libs/langchain/langchain/schema/runnable/utils.py @@ -1,327 +1,37 @@ -from __future__ import annotations - -import ast -import asyncio -import inspect -import textwrap -from inspect import signature -from itertools import groupby -from typing import ( - Any, - AsyncIterable, - Callable, - Coroutine, - Dict, - Iterable, - List, - Mapping, - NamedTuple, - Optional, - Protocol, - Sequence, - Set, - TypeVar, - Union, +from langchain_core.runnables.utils import ( + AddableDict, + ConfigurableField, + ConfigurableFieldMultiOption, + ConfigurableFieldSingleOption, + ConfigurableFieldSpec, + GetLambdaSource, + IsFunctionArgDict, + IsLocalDict, + SupportsAdd, + accepts_config, + accepts_run_manager, + add, + get_function_first_arg_dict_keys, + get_lambda_source, + get_unique_config_specs, + indent_lines_after_first, ) -Input = TypeVar("Input", contravariant=True) -# Output type should implement __concat__, as eg str, list, dict do -Output = TypeVar("Output", covariant=True) - - -async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any: - """Run a coroutine with a semaphore. - Args: - semaphore: The semaphore to use. - coro: The coroutine to run. - - Returns: - The result of the coroutine. 
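`gated_coro` and `gather_with_concurrency` above cap how many coroutines run at once by sharing a semaphore. A self-contained sketch of that pattern using only the standard library (the sleeping `worker` coroutine is illustrative):

.. code-block:: python

    import asyncio

    async def gated(semaphore: asyncio.Semaphore, coro):
        # Wait for a slot on the shared semaphore before awaiting the coroutine.
        async with semaphore:
            return await coro

    async def worker(i: int) -> int:
        await asyncio.sleep(0.01)
        return i * i

    async def main() -> None:
        semaphore = asyncio.Semaphore(2)  # at most two workers in flight
        results = await asyncio.gather(
            *(gated(semaphore, worker(i)) for i in range(5))
        )
        print(results)  # [0, 1, 4, 9, 16]

    asyncio.run(main())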
- """ - async with semaphore: - return await coro - - -async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list: - """Gather coroutines with a limit on the number of concurrent coroutines.""" - if n is None: - return await asyncio.gather(*coros) - - semaphore = asyncio.Semaphore(n) - - return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros)) - - -def accepts_run_manager(callable: Callable[..., Any]) -> bool: - """Check if a callable accepts a run_manager argument.""" - try: - return signature(callable).parameters.get("run_manager") is not None - except ValueError: - return False - - -def accepts_config(callable: Callable[..., Any]) -> bool: - """Check if a callable accepts a config argument.""" - try: - return signature(callable).parameters.get("config") is not None - except ValueError: - return False - - -class IsLocalDict(ast.NodeVisitor): - """Check if a name is a local dict.""" - - def __init__(self, name: str, keys: Set[str]) -> None: - self.name = name - self.keys = keys - - def visit_Subscript(self, node: ast.Subscript) -> Any: - if ( - isinstance(node.ctx, ast.Load) - and isinstance(node.value, ast.Name) - and node.value.id == self.name - and isinstance(node.slice, ast.Constant) - and isinstance(node.slice.value, str) - ): - # we've found a subscript access on the name we're looking for - self.keys.add(node.slice.value) - - def visit_Call(self, node: ast.Call) -> Any: - if ( - isinstance(node.func, ast.Attribute) - and isinstance(node.func.value, ast.Name) - and node.func.value.id == self.name - and node.func.attr == "get" - and len(node.args) in (1, 2) - and isinstance(node.args[0], ast.Constant) - and isinstance(node.args[0].value, str) - ): - # we've found a .get() call on the name we're looking for - self.keys.add(node.args[0].value) - - -class IsFunctionArgDict(ast.NodeVisitor): - """Check if the first argument of a function is a dict.""" - - def __init__(self) -> None: - self.keys: Set[str] = set() - - def visit_Lambda(self, node: ast.Lambda) -> Any: - input_arg_name = node.args.args[0].arg - IsLocalDict(input_arg_name, self.keys).visit(node.body) - - def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: - input_arg_name = node.args.args[0].arg - IsLocalDict(input_arg_name, self.keys).visit(node) - - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: - input_arg_name = node.args.args[0].arg - IsLocalDict(input_arg_name, self.keys).visit(node) - - -class GetLambdaSource(ast.NodeVisitor): - """Get the source code of a lambda function.""" - - def __init__(self) -> None: - """Initialize the visitor.""" - self.source: Optional[str] = None - self.count = 0 - - def visit_Lambda(self, node: ast.Lambda) -> Any: - """Visit a lambda function.""" - self.count += 1 - if hasattr(ast, "unparse"): - self.source = ast.unparse(node) - - -def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]: - """Get the keys of the first argument of a function if it is a dict.""" - try: - code = inspect.getsource(func) - tree = ast.parse(textwrap.dedent(code)) - visitor = IsFunctionArgDict() - visitor.visit(tree) - return list(visitor.keys) if visitor.keys else None - except (SyntaxError, TypeError, OSError): - return None - - -def get_lambda_source(func: Callable) -> Optional[str]: - """Get the source code of a lambda function. 
- - Args: - func: a callable that can be a lambda function - - Returns: - str: the source code of the lambda function - """ - try: - code = inspect.getsource(func) - tree = ast.parse(textwrap.dedent(code)) - visitor = GetLambdaSource() - visitor.visit(tree) - return visitor.source if visitor.count == 1 else None - except (SyntaxError, TypeError, OSError): - return None - - -def indent_lines_after_first(text: str, prefix: str) -> str: - """Indent all lines of text after the first line. - - Args: - text: The text to indent - prefix: Used to determine the number of spaces to indent - - Returns: - str: The indented text - """ - n_spaces = len(prefix) - spaces = " " * n_spaces - lines = text.splitlines() - return "\n".join([lines[0]] + [spaces + line for line in lines[1:]]) - - -class AddableDict(Dict[str, Any]): - """ - Dictionary that can be added to another dictionary. - """ - - def __add__(self, other: AddableDict) -> AddableDict: - chunk = AddableDict(self) - for key in other: - if key not in chunk or chunk[key] is None: - chunk[key] = other[key] - elif other[key] is not None: - try: - added = chunk[key] + other[key] - except TypeError: - added = other[key] - chunk[key] = added - return chunk - - def __radd__(self, other: AddableDict) -> AddableDict: - chunk = AddableDict(other) - for key in self: - if key not in chunk or chunk[key] is None: - chunk[key] = self[key] - elif self[key] is not None: - try: - added = chunk[key] + self[key] - except TypeError: - added = self[key] - chunk[key] = added - return chunk - - -_T_co = TypeVar("_T_co", covariant=True) -_T_contra = TypeVar("_T_contra", contravariant=True) - - -class SupportsAdd(Protocol[_T_contra, _T_co]): - """Protocol for objects that support addition.""" - - def __add__(self, __x: _T_contra) -> _T_co: - ... 
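`AddableDict` is what allows streamed dict chunks to be accumulated with `+`: missing or None values are filled in, and values that support addition are concatenated. A small sketch of that behaviour (the chunk contents are illustrative):

.. code-block:: python

    from langchain.schema.runnable.utils import AddableDict

    first = AddableDict({"answer": "Hel", "run_id": None})
    second = AddableDict({"answer": "lo!", "run_id": "abc", "tokens": 3})

    merged = first + second
    print(merged)
    # {'answer': 'Hello!', 'run_id': 'abc', 'tokens': 3}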
- - -Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any]) - - -def add(addables: Iterable[Addable]) -> Optional[Addable]: - """Add a sequence of addable objects together.""" - final = None - for chunk in addables: - if final is None: - final = chunk - else: - final = final + chunk - return final - - -async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]: - """Asynchronously add a sequence of addable objects together.""" - final = None - async for chunk in addables: - if final is None: - final = chunk - else: - final = final + chunk - return final - - -class ConfigurableField(NamedTuple): - """A field that can be configured by the user.""" - - id: str - - name: Optional[str] = None - description: Optional[str] = None - annotation: Optional[Any] = None - - def __hash__(self) -> int: - return hash((self.id, self.annotation)) - - -class ConfigurableFieldSingleOption(NamedTuple): - """A field that can be configured by the user with a default value.""" - - id: str - options: Mapping[str, Any] - default: str - - name: Optional[str] = None - description: Optional[str] = None - - def __hash__(self) -> int: - return hash((self.id, tuple(self.options.keys()), self.default)) - - -class ConfigurableFieldMultiOption(NamedTuple): - """A field that can be configured by the user with multiple default values.""" - - id: str - options: Mapping[str, Any] - default: Sequence[str] - - name: Optional[str] = None - description: Optional[str] = None - - def __hash__(self) -> int: - return hash((self.id, tuple(self.options.keys()), tuple(self.default))) - - -AnyConfigurableField = Union[ - ConfigurableField, ConfigurableFieldSingleOption, ConfigurableFieldMultiOption +__all__ = [ + "accepts_run_manager", + "accepts_config", + "IsLocalDict", + "IsFunctionArgDict", + "GetLambdaSource", + "get_function_first_arg_dict_keys", + "get_lambda_source", + "indent_lines_after_first", + "AddableDict", + "SupportsAdd", + "add", + "ConfigurableField", + "ConfigurableFieldSingleOption", + "ConfigurableFieldMultiOption", + "ConfigurableFieldSpec", + "get_unique_config_specs", ] - - -class ConfigurableFieldSpec(NamedTuple): - """A field that can be configured by the user. 
It is a specification of a field.""" - - id: str - name: Optional[str] - description: Optional[str] - - default: Any - annotation: Any - - -def get_unique_config_specs( - specs: Iterable[ConfigurableFieldSpec], -) -> List[ConfigurableFieldSpec]: - """Get the unique config specs from a sequence of config specs.""" - grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id) - unique: List[ConfigurableFieldSpec] = [] - for id, dupes in grouped: - first = next(dupes) - others = list(dupes) - if len(others) == 0: - unique.append(first) - elif all(o == first for o in others): - unique.append(first) - else: - raise ValueError( - "RunnableSequence contains conflicting config specs" - f"for {id}: {[first] + others}" - ) - return unique diff --git a/libs/langchain/langchain/schema/storage.py b/libs/langchain/langchain/schema/storage.py index bae5adc2b8e..7ed3443be8d 100644 --- a/libs/langchain/langchain/schema/storage.py +++ b/libs/langchain/langchain/schema/storage.py @@ -1,53 +1,3 @@ -from abc import ABC, abstractmethod -from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union +from langchain_core.schema.storage import BaseStore -K = TypeVar("K") -V = TypeVar("V") - - -class BaseStore(Generic[K, V], ABC): - """Abstract interface for a key-value store.""" - - @abstractmethod - def mget(self, keys: Sequence[K]) -> List[Optional[V]]: - """Get the values associated with the given keys. - - Args: - keys (Sequence[K]): A sequence of keys. - - Returns: - A sequence of optional values associated with the keys. - If a key is not found, the corresponding value will be None. - """ - - @abstractmethod - def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None: - """Set the values for the given keys. - - Args: - key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs. - """ - - @abstractmethod - def mdelete(self, keys: Sequence[K]) -> None: - """Delete the given keys and their associated values. - - Args: - keys (Sequence[K]): A sequence of keys to delete. - """ - - @abstractmethod - def yield_keys( - self, *, prefix: Optional[str] = None - ) -> Union[Iterator[K], Iterator[str]]: - """Get an iterator over keys that match the given prefix. - - Args: - prefix (str): The prefix to match. - - Returns: - Iterator[K | str]: An iterator over keys that match the given prefix. - - This method is allowed to return an iterator over either K or str - depending on what makes more sense for the given store. 
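`BaseStore` only requires the four batch operations declared above, so a toy in-memory implementation fits in a few lines. A sketch assuming the interface exactly as declared here (the dict-backed class is purely illustrative; `langchain.storage.in_memory` ships a real one):

.. code-block:: python

    from typing import Dict, Iterator, List, Optional, Sequence, Tuple

    from langchain.schema import BaseStore

    class DictStore(BaseStore[str, bytes]):
        """Toy key-value store backed by a plain dict."""

        def __init__(self) -> None:
            self._data: Dict[str, bytes] = {}

        def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
            return [self._data.get(k) for k in keys]

        def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
            self._data.update(key_value_pairs)

        def mdelete(self, keys: Sequence[str]) -> None:
            for k in keys:
                self._data.pop(k, None)

        def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
            for k in self._data:
                if prefix is None or k.startswith(prefix):
                    yield k

    store = DictStore()
    store.mset([("doc:1", b"hello"), ("doc:2", b"world")])
    print(store.mget(["doc:1", "missing"]))       # [b'hello', None]
    print(list(store.yield_keys(prefix="doc:")))  # ['doc:1', 'doc:2']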
- """ +__all__ = ["BaseStore"] diff --git a/libs/langchain/langchain/schema/vectorstore.py b/libs/langchain/langchain/schema/vectorstore.py index 5e5e08a7a0c..5be4e018853 100644 --- a/libs/langchain/langchain/schema/vectorstore.py +++ b/libs/langchain/langchain/schema/vectorstore.py @@ -1,702 +1,3 @@ -from __future__ import annotations +from langchain_core.schema.vectorstore import VectorStore, VectorStoreRetriever -import asyncio -import logging -import math -import warnings -from abc import ABC, abstractmethod -from functools import partial -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Collection, - Dict, - Iterable, - List, - Optional, - Tuple, - Type, - TypeVar, -) - -from langchain.pydantic_v1 import Field, root_validator -from langchain.schema import BaseRetriever -from langchain.schema.document import Document -from langchain.schema.embeddings import Embeddings - -if TYPE_CHECKING: - from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, - ) - -logger = logging.getLogger(__name__) - -VST = TypeVar("VST", bound="VectorStore") - - -class VectorStore(ABC): - """Interface for vector store.""" - - @abstractmethod - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - kwargs: vectorstore specific parameters - - Returns: - List of ids from adding the texts into the vectorstore. - """ - - @property - def embeddings(self) -> Optional[Embeddings]: - """Access the query embedding object if available.""" - logger.debug( - f"{Embeddings.__name__} is not implemented for {self.__class__.__name__}" - ) - return None - - def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: - """Delete by vector ID or other criteria. - - Args: - ids: List of ids to delete. - **kwargs: Other keyword arguments that subclasses might use. - - Returns: - Optional[bool]: True if deletion is successful, - False otherwise, None if not implemented. - """ - - raise NotImplementedError("delete method must be implemented by subclass.") - - async def adelete( - self, ids: Optional[List[str]] = None, **kwargs: Any - ) -> Optional[bool]: - """Delete by vector ID or other criteria. - - Args: - ids: List of ids to delete. - **kwargs: Other keyword arguments that subclasses might use. - - Returns: - Optional[bool]: True if deletion is successful, - False otherwise, None if not implemented. - """ - - raise NotImplementedError("delete method must be implemented by subclass.") - - async def aadd_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore.""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.add_texts, **kwargs), texts, metadatas - ) - - def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: - """Run more documents through the embeddings and add to the vectorstore. - - Args: - documents (List[Document]: Documents to add to the vectorstore. - - Returns: - List[str]: List of IDs of the added texts. 
- """ - # TODO: Handle the case where the user doesn't provide ids on the Collection - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] - return self.add_texts(texts, metadatas, **kwargs) - - async def aadd_documents( - self, documents: List[Document], **kwargs: Any - ) -> List[str]: - """Run more documents through the embeddings and add to the vectorstore. - - Args: - documents (List[Document]: Documents to add to the vectorstore. - - Returns: - List[str]: List of IDs of the added texts. - """ - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] - return await self.aadd_texts(texts, metadatas, **kwargs) - - def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]: - """Return docs most similar to query using specified search type.""" - if search_type == "similarity": - return self.similarity_search(query, **kwargs) - elif search_type == "mmr": - return self.max_marginal_relevance_search(query, **kwargs) - else: - raise ValueError( - f"search_type of {search_type} not allowed. Expected " - "search_type to be 'similarity' or 'mmr'." - ) - - async def asearch( - self, query: str, search_type: str, **kwargs: Any - ) -> List[Document]: - """Return docs most similar to query using specified search type.""" - if search_type == "similarity": - return await self.asimilarity_search(query, **kwargs) - elif search_type == "mmr": - return await self.amax_marginal_relevance_search(query, **kwargs) - else: - raise ValueError( - f"search_type of {search_type} not allowed. Expected " - "search_type to be 'similarity' or 'mmr'." - ) - - @abstractmethod - def similarity_search( - self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: - """Return docs most similar to query.""" - - @staticmethod - def _euclidean_relevance_score_fn(distance: float) -> float: - """Return a similarity score on a scale [0, 1].""" - # The 'correct' relevance function - # may differ depending on a few things, including: - # - the distance / similarity metric used by the VectorStore - # - the scale of your embeddings (OpenAI's are unit normed. Many - # others are not!) - # - embedding dimensionality - # - etc. - # This function converts the euclidean norm of normalized embeddings - # (0 is most similar, sqrt(2) most dissimilar) - # to a similarity function (0 to 1) - return 1.0 - distance / math.sqrt(2) - - @staticmethod - def _cosine_relevance_score_fn(distance: float) -> float: - """Normalize the distance to a score on a scale [0, 1].""" - - return 1.0 - distance - - @staticmethod - def _max_inner_product_relevance_score_fn(distance: float) -> float: - """Normalize the distance to a score on a scale [0, 1].""" - if distance > 0: - return 1.0 - distance - - return -1.0 * distance - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """ - The 'correct' relevance function - may differ depending on a few things, including: - - the distance / similarity metric used by the VectorStore - - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - - embedding dimensionality - - etc. - - Vectorstores should define their own selection based method of relevance. 
- """ - raise NotImplementedError - - def similarity_search_with_score( - self, *args: Any, **kwargs: Any - ) -> List[Tuple[Document, float]]: - """Run similarity search with distance.""" - raise NotImplementedError - - async def asimilarity_search_with_score( - self, *args: Any, **kwargs: Any - ) -> List[Tuple[Document, float]]: - """Run similarity search with distance asynchronously.""" - - # This is a temporary workaround to make the similarity search - # asynchronous. The proper solution is to make the similarity search - # asynchronous in the vector store implementations. - func = partial(self.similarity_search_with_score, *args, **kwargs) - return await asyncio.get_event_loop().run_in_executor(None, func) - - def _similarity_search_with_relevance_scores( - self, - query: str, - k: int = 4, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """ - Default similarity search with relevance scores. Modify if necessary - in subclass. - Return docs and relevance scores in the range [0, 1]. - - 0 is dissimilar, 1 is most similar. - - Args: - query: input text - k: Number of Documents to return. Defaults to 4. - **kwargs: kwargs to be passed to similarity search. Should include: - score_threshold: Optional, a floating point value between 0 to 1 to - filter the resulting set of retrieved docs - - Returns: - List of Tuples of (doc, similarity_score) - """ - relevance_score_fn = self._select_relevance_score_fn() - docs_and_scores = self.similarity_search_with_score(query, k, **kwargs) - return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores] - - async def _asimilarity_search_with_relevance_scores( - self, - query: str, - k: int = 4, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """ - Default async similarity search with relevance scores. Modify if necessary - in subclass. - Return docs and relevance scores in the range [0, 1]. - - 0 is dissimilar, 1 is most similar. - - Args: - query: input text - k: Number of Documents to return. Defaults to 4. - **kwargs: kwargs to be passed to similarity search. Should include: - score_threshold: Optional, a floating point value between 0 to 1 to - filter the resulting set of retrieved docs - - Returns: - List of Tuples of (doc, similarity_score) - """ - relevance_score_fn = self._select_relevance_score_fn() - docs_and_scores = await self.asimilarity_search_with_score(query, k, **kwargs) - return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores] - - def similarity_search_with_relevance_scores( - self, - query: str, - k: int = 4, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs and relevance scores in the range [0, 1]. - - 0 is dissimilar, 1 is most similar. - - Args: - query: input text - k: Number of Documents to return. Defaults to 4. - **kwargs: kwargs to be passed to similarity search. 
Should include: - score_threshold: Optional, a floating point value between 0 to 1 to - filter the resulting set of retrieved docs - - Returns: - List of Tuples of (doc, similarity_score) - """ - score_threshold = kwargs.pop("score_threshold", None) - - docs_and_similarities = self._similarity_search_with_relevance_scores( - query, k=k, **kwargs - ) - if any( - similarity < 0.0 or similarity > 1.0 - for _, similarity in docs_and_similarities - ): - warnings.warn( - "Relevance scores must be between" - f" 0 and 1, got {docs_and_similarities}" - ) - - if score_threshold is not None: - docs_and_similarities = [ - (doc, similarity) - for doc, similarity in docs_and_similarities - if similarity >= score_threshold - ] - if len(docs_and_similarities) == 0: - warnings.warn( - "No relevant docs were retrieved using the relevance score" - f" threshold {score_threshold}" - ) - return docs_and_similarities - - async def asimilarity_search_with_relevance_scores( - self, - query: str, - k: int = 4, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs and relevance scores in the range [0, 1], asynchronously. - - 0 is dissimilar, 1 is most similar. - - Args: - query: input text - k: Number of Documents to return. Defaults to 4. - **kwargs: kwargs to be passed to similarity search. Should include: - score_threshold: Optional, a floating point value between 0 to 1 to - filter the resulting set of retrieved docs - - Returns: - List of Tuples of (doc, similarity_score) - """ - score_threshold = kwargs.pop("score_threshold", None) - - docs_and_similarities = await self._asimilarity_search_with_relevance_scores( - query, k=k, **kwargs - ) - if any( - similarity < 0.0 or similarity > 1.0 - for _, similarity in docs_and_similarities - ): - warnings.warn( - "Relevance scores must be between" - f" 0 and 1, got {docs_and_similarities}" - ) - - if score_threshold is not None: - docs_and_similarities = [ - (doc, similarity) - for doc, similarity in docs_and_similarities - if similarity >= score_threshold - ] - if len(docs_and_similarities) == 0: - warnings.warn( - "No relevant docs were retrieved using the relevance score" - f" threshold {score_threshold}" - ) - return docs_and_similarities - - async def asimilarity_search( - self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: - """Return docs most similar to query.""" - - # This is a temporary workaround to make the similarity search - # asynchronous. The proper solution is to make the similarity search - # asynchronous in the vector store implementations. - func = partial(self.similarity_search, query, k=k, **kwargs) - return await asyncio.get_event_loop().run_in_executor(None, func) - - def similarity_search_by_vector( - self, embedding: List[float], k: int = 4, **kwargs: Any - ) -> List[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - - Returns: - List of Documents most similar to the query vector. - """ - raise NotImplementedError - - async def asimilarity_search_by_vector( - self, embedding: List[float], k: int = 4, **kwargs: Any - ) -> List[Document]: - """Return docs most similar to embedding vector.""" - - # This is a temporary workaround to make the similarity search - # asynchronous. The proper solution is to make the similarity search - # asynchronous in the vector store implementations. 
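Several of the async methods above share the same interim pattern: wrap the synchronous implementation with `functools.partial` and hand it to the event loop's executor so the loop is not blocked. A standalone sketch of that pattern (`slow_search` is an illustrative stand-in for a blocking vector-store call):

.. code-block:: python

    import asyncio
    import time
    from functools import partial

    def slow_search(query: str, k: int = 4) -> list:
        # Pretend this is a blocking similarity search against a vector store.
        time.sleep(0.05)
        return [f"{query}-doc-{i}" for i in range(k)]

    async def aslow_search(query: str, k: int = 4) -> list:
        # Run the blocking call in the default thread pool executor.
        func = partial(slow_search, query, k=k)
        return await asyncio.get_running_loop().run_in_executor(None, func)

    print(asyncio.run(aslow_search("hello", k=2)))
    # ['hello-doc-0', 'hello-doc-1']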
- func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs) - return await asyncio.get_event_loop().run_in_executor(None, func) - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance. - """ - raise NotImplementedError - - async def amax_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance.""" - - # This is a temporary workaround to make the similarity search - # asynchronous. The proper solution is to make the similarity search - # asynchronous in the vector store implementations. - func = partial( - self.max_marginal_relevance_search, - query, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - **kwargs, - ) - return await asyncio.get_event_loop().run_in_executor(None, func) - - def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance. 
- """ - raise NotImplementedError - - async def amax_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance.""" - raise NotImplementedError - - @classmethod - def from_documents( - cls: Type[VST], - documents: List[Document], - embedding: Embeddings, - **kwargs: Any, - ) -> VST: - """Return VectorStore initialized from documents and embeddings.""" - texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] - return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs) - - @classmethod - async def afrom_documents( - cls: Type[VST], - documents: List[Document], - embedding: Embeddings, - **kwargs: Any, - ) -> VST: - """Return VectorStore initialized from documents and embeddings.""" - texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] - return await cls.afrom_texts(texts, embedding, metadatas=metadatas, **kwargs) - - @classmethod - @abstractmethod - def from_texts( - cls: Type[VST], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> VST: - """Return VectorStore initialized from texts and embeddings.""" - - @classmethod - async def afrom_texts( - cls: Type[VST], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> VST: - """Return VectorStore initialized from texts and embeddings.""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas - ) - - def _get_retriever_tags(self) -> List[str]: - """Get tags for retriever.""" - tags = [self.__class__.__name__] - if self.embeddings: - tags.append(self.embeddings.__class__.__name__) - return tags - - def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever: - """Return VectorStoreRetriever initialized from this VectorStore. - - Args: - search_type (Optional[str]): Defines the type of search that - the Retriever should perform. - Can be "similarity" (default), "mmr", or - "similarity_score_threshold". - search_kwargs (Optional[Dict]): Keyword arguments to pass to the - search function. Can include things like: - k: Amount of documents to return (Default: 4) - score_threshold: Minimum relevance threshold - for similarity_score_threshold - fetch_k: Amount of documents to pass to MMR algorithm (Default: 20) - lambda_mult: Diversity of results returned by MMR; - 1 for minimum diversity and 0 for maximum. (Default: 0.5) - filter: Filter by document metadata - - Returns: - VectorStoreRetriever: Retriever class for VectorStore. - - Examples: - - .. 
code-block:: python - - # Retrieve more documents with higher diversity - # Useful if your dataset has many similar documents - docsearch.as_retriever( - search_type="mmr", - search_kwargs={'k': 6, 'lambda_mult': 0.25} - ) - - # Fetch more documents for the MMR algorithm to consider - # But only return the top 5 - docsearch.as_retriever( - search_type="mmr", - search_kwargs={'k': 5, 'fetch_k': 50} - ) - - # Only retrieve documents that have a relevance score - # Above a certain threshold - docsearch.as_retriever( - search_type="similarity_score_threshold", - search_kwargs={'score_threshold': 0.8} - ) - - # Only get the single most similar document from the dataset - docsearch.as_retriever(search_kwargs={'k': 1}) - - # Use a filter to only retrieve documents from a specific paper - docsearch.as_retriever( - search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}} - ) - """ - tags = kwargs.pop("tags", None) or [] - tags.extend(self._get_retriever_tags()) - return VectorStoreRetriever(vectorstore=self, **kwargs, tags=tags) - - -class VectorStoreRetriever(BaseRetriever): - """Base Retriever class for VectorStore.""" - - vectorstore: VectorStore - """VectorStore to use for retrieval.""" - search_type: str = "similarity" - """Type of search to perform. Defaults to "similarity".""" - search_kwargs: dict = Field(default_factory=dict) - """Keyword arguments to pass to the search function.""" - allowed_search_types: ClassVar[Collection[str]] = ( - "similarity", - "similarity_score_threshold", - "mmr", - ) - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @root_validator() - def validate_search_type(cls, values: Dict) -> Dict: - """Validate search type.""" - search_type = values["search_type"] - if search_type not in cls.allowed_search_types: - raise ValueError( - f"search_type of {search_type} not allowed. Valid values are: " - f"{cls.allowed_search_types}" - ) - if search_type == "similarity_score_threshold": - score_threshold = values["search_kwargs"].get("score_threshold") - if (score_threshold is None) or (not isinstance(score_threshold, float)): - raise ValueError( - "`score_threshold` is not specified with a float value(0~1) " - "in `search_kwargs`." 
- ) - return values - - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: - if self.search_type == "similarity": - docs = self.vectorstore.similarity_search(query, **self.search_kwargs) - elif self.search_type == "similarity_score_threshold": - docs_and_similarities = ( - self.vectorstore.similarity_search_with_relevance_scores( - query, **self.search_kwargs - ) - ) - docs = [doc for doc, _ in docs_and_similarities] - elif self.search_type == "mmr": - docs = self.vectorstore.max_marginal_relevance_search( - query, **self.search_kwargs - ) - else: - raise ValueError(f"search_type of {self.search_type} not allowed.") - return docs - - async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: - if self.search_type == "similarity": - docs = await self.vectorstore.asimilarity_search( - query, **self.search_kwargs - ) - elif self.search_type == "similarity_score_threshold": - docs_and_similarities = ( - await self.vectorstore.asimilarity_search_with_relevance_scores( - query, **self.search_kwargs - ) - ) - docs = [doc for doc, _ in docs_and_similarities] - elif self.search_type == "mmr": - docs = await self.vectorstore.amax_marginal_relevance_search( - query, **self.search_kwargs - ) - else: - raise ValueError(f"search_type of {self.search_type} not allowed.") - return docs - - def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: - """Add documents to vectorstore.""" - return self.vectorstore.add_documents(documents, **kwargs) - - async def aadd_documents( - self, documents: List[Document], **kwargs: Any - ) -> List[str]: - """Add documents to vectorstore.""" - return await self.vectorstore.aadd_documents(documents, **kwargs) +__all__ = ["VectorStore", "VectorStoreRetriever"] diff --git a/libs/langchain/langchain/smith/evaluation/config.py b/libs/langchain/langchain/smith/evaluation/config.py index f7234528649..95483cce516 100644 --- a/libs/langchain/langchain/smith/evaluation/config.py +++ b/libs/langchain/langchain/smith/evaluation/config.py @@ -2,6 +2,10 @@ from typing import Any, Dict, List, Optional, Union +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.prompt_template import BasePromptTemplate from langsmith import RunEvaluator from langchain.evaluation.criteria.eval_chain import CRITERIA_TYPE @@ -12,10 +16,6 @@ from langchain.evaluation.schema import EvaluatorType, StringEvaluator from langchain.evaluation.string_distance.base import ( StringDistance as StringDistanceEnum, ) -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema.embeddings import Embeddings -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.prompt_template import BasePromptTemplate class EvalConfig(BaseModel): diff --git a/libs/langchain/langchain/smith/evaluation/progress.py b/libs/langchain/langchain/smith/evaluation/progress.py index a0f8c4fc4ce..4471ed18289 100644 --- a/libs/langchain/langchain/smith/evaluation/progress.py +++ b/libs/langchain/langchain/smith/evaluation/progress.py @@ -3,9 +3,10 @@ import threading from typing import Any, Dict, Optional, Sequence from uuid import UUID +from langchain_core.schema.document import Document +from langchain_core.schema.output import LLMResult + from langchain.callbacks import base as 
base_callbacks -from langchain.schema.document import Document -from langchain.schema.output import LLMResult class ProgressBarCallback(base_callbacks.BaseCallbackHandler): diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index 403348c523f..38593e27b00 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -20,6 +20,13 @@ from typing import ( cast, ) +from langchain_core._api import warn_deprecated +from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda +from langchain_core.runnables import config as runnable_config +from langchain_core.runnables import utils as runnable_utils +from langchain_core.schema import ChatResult, LLMResult +from langchain_core.schema.language_model import BaseLanguageModel +from langchain_core.schema.messages import BaseMessage, messages_from_dict from langsmith.client import Client from langsmith.evaluation import RunEvaluator from langsmith.run_helpers import as_runnable, is_traceable_function @@ -27,7 +34,6 @@ from langsmith.schemas import Dataset, DataType, Example from langsmith.utils import LangSmithError from requests import HTTPError -from langchain._api import warn_deprecated from langchain.callbacks.manager import Callbacks from langchain.callbacks.tracers.evaluation import ( EvaluatorCallbackHandler, @@ -41,12 +47,6 @@ from langchain.evaluation.schema import ( PairwiseStringEvaluator, StringEvaluator, ) -from langchain.schema import ChatResult, LLMResult -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.messages import BaseMessage, messages_from_dict -from langchain.schema.runnable import Runnable, RunnableConfig, RunnableLambda -from langchain.schema.runnable import config as runnable_config -from langchain.schema.runnable import utils as runnable_utils from langchain.smith import evaluation as smith_eval from langchain.smith.evaluation import config as smith_eval_config from langchain.smith.evaluation import name_generation, progress diff --git a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py index 36ddd67db66..6c9a362a7f6 100644 --- a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py +++ b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py @@ -4,6 +4,11 @@ from __future__ import annotations from abc import abstractmethod from typing import Any, Dict, List, Optional +from langchain_core.load.dump import dumpd +from langchain_core.load.load import load +from langchain_core.load.serializable import Serializable +from langchain_core.schema import RUN_KEY, messages_from_dict +from langchain_core.schema.messages import BaseMessage, get_buffer_string from langsmith import EvaluationResult, RunEvaluator from langsmith.schemas import DataType, Example, Run @@ -13,11 +18,6 @@ from langchain.callbacks.manager import ( ) from langchain.chains.base import Chain from langchain.evaluation.schema import StringEvaluator -from langchain.load.dump import dumpd -from langchain.load.load import load -from langchain.load.serializable import Serializable -from langchain.schema import RUN_KEY, messages_from_dict -from langchain.schema.messages import BaseMessage, get_buffer_string def _get_messages_from_run_dict(messages: List[dict]) -> List[BaseMessage]: diff --git a/libs/langchain/langchain/storage/_lc_store.py 
b/libs/langchain/langchain/storage/_lc_store.py index be528e77480..73916b9ded2 100644 --- a/libs/langchain/langchain/storage/_lc_store.py +++ b/libs/langchain/langchain/storage/_lc_store.py @@ -1,10 +1,11 @@ """Create a key-value store for any langchain serializable object.""" from typing import Callable, Optional -from langchain.load.dump import dumps -from langchain.load.load import loads -from langchain.load.serializable import Serializable -from langchain.schema import BaseStore, Document +from langchain_core.load.dump import dumps +from langchain_core.load.load import loads +from langchain_core.load.serializable import Serializable +from langchain_core.schema import BaseStore, Document + from langchain.storage.encoder_backed import EncoderBackedStore diff --git a/libs/langchain/langchain/storage/encoder_backed.py b/libs/langchain/langchain/storage/encoder_backed.py index 4a713cc3912..026d76f4dd4 100644 --- a/libs/langchain/langchain/storage/encoder_backed.py +++ b/libs/langchain/langchain/storage/encoder_backed.py @@ -10,7 +10,7 @@ from typing import ( Union, ) -from langchain.schema import BaseStore +from langchain_core.schema import BaseStore K = TypeVar("K") V = TypeVar("V") diff --git a/libs/langchain/langchain/storage/exceptions.py b/libs/langchain/langchain/storage/exceptions.py index 2f36a7615c7..fedc9c7bf87 100644 --- a/libs/langchain/langchain/storage/exceptions.py +++ b/libs/langchain/langchain/storage/exceptions.py @@ -1,4 +1,4 @@ -from langchain.schema import LangChainException +from langchain_core.schema import LangChainException class InvalidKeyException(LangChainException): diff --git a/libs/langchain/langchain/storage/file_system.py b/libs/langchain/langchain/storage/file_system.py index d197497d8ae..3cead7e0236 100644 --- a/libs/langchain/langchain/storage/file_system.py +++ b/libs/langchain/langchain/storage/file_system.py @@ -2,7 +2,8 @@ import re from pathlib import Path from typing import Iterator, List, Optional, Sequence, Tuple, Union -from langchain.schema import BaseStore +from langchain_core.schema import BaseStore + from langchain.storage.exceptions import InvalidKeyException diff --git a/libs/langchain/langchain/storage/in_memory.py b/libs/langchain/langchain/storage/in_memory.py index 3350f75f4d4..48e014e839d 100644 --- a/libs/langchain/langchain/storage/in_memory.py +++ b/libs/langchain/langchain/storage/in_memory.py @@ -5,7 +5,7 @@ primarily for unit testing purposes. 
""" from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple -from langchain.schema import BaseStore +from langchain_core.schema import BaseStore class InMemoryStore(BaseStore[str, Any]): diff --git a/libs/langchain/langchain/storage/redis.py b/libs/langchain/langchain/storage/redis.py index bfb3e4a4ae0..6b5efbbc5e1 100644 --- a/libs/langchain/langchain/storage/redis.py +++ b/libs/langchain/langchain/storage/redis.py @@ -1,6 +1,7 @@ from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast -from langchain.schema import BaseStore +from langchain_core.schema import BaseStore + from langchain.utilities.redis import get_client diff --git a/libs/langchain/langchain/storage/upstash_redis.py b/libs/langchain/langchain/storage/upstash_redis.py index e9dc4fd6557..194982203bd 100644 --- a/libs/langchain/langchain/storage/upstash_redis.py +++ b/libs/langchain/langchain/storage/upstash_redis.py @@ -1,6 +1,6 @@ from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast -from langchain.schema import BaseStore +from langchain_core.schema import BaseStore class UpstashRedisStore(BaseStore[str, str]): diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py index 25c121866ab..095693e3dae 100644 --- a/libs/langchain/langchain/text_splitter.py +++ b/libs/langchain/langchain/text_splitter.py @@ -51,9 +51,9 @@ from typing import ( ) import requests +from langchain_core.schema import BaseDocumentTransformer from langchain.docstore.document import Document -from langchain.schema import BaseDocumentTransformer logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/tools/ainetwork/app.py b/libs/langchain/langchain/tools/ainetwork/app.py index 64c32046e86..93c7bad3940 100644 --- a/libs/langchain/langchain/tools/ainetwork/app.py +++ b/libs/langchain/langchain/tools/ainetwork/app.py @@ -3,8 +3,9 @@ import json from enum import Enum from typing import List, Optional, Type, Union +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import AsyncCallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.ainetwork.base import AINBaseTool diff --git a/libs/langchain/langchain/tools/ainetwork/base.py b/libs/langchain/langchain/tools/ainetwork/base.py index 0d40f403896..c38724fb019 100644 --- a/libs/langchain/langchain/tools/ainetwork/base.py +++ b/libs/langchain/langchain/tools/ainetwork/base.py @@ -5,8 +5,9 @@ import threading from enum import Enum from typing import TYPE_CHECKING, Any, Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import Field from langchain.tools.ainetwork.utils import authenticate from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/tools/ainetwork/owner.py b/libs/langchain/langchain/tools/ainetwork/owner.py index 33c182ac710..60e43a9a1de 100644 --- a/libs/langchain/langchain/tools/ainetwork/owner.py +++ b/libs/langchain/langchain/tools/ainetwork/owner.py @@ -2,8 +2,9 @@ import builtins import json from typing import List, Optional, Type, Union +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import AsyncCallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.ainetwork.base import AINBaseTool, OperationType diff --git a/libs/langchain/langchain/tools/ainetwork/rule.py 
b/libs/langchain/langchain/tools/ainetwork/rule.py index 030edb2eb37..768c9f36e38 100644 --- a/libs/langchain/langchain/tools/ainetwork/rule.py +++ b/libs/langchain/langchain/tools/ainetwork/rule.py @@ -2,8 +2,9 @@ import builtins import json from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import AsyncCallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.ainetwork.base import AINBaseTool, OperationType diff --git a/libs/langchain/langchain/tools/ainetwork/transfer.py b/libs/langchain/langchain/tools/ainetwork/transfer.py index 04f15c6748b..eab8cc8810d 100644 --- a/libs/langchain/langchain/tools/ainetwork/transfer.py +++ b/libs/langchain/langchain/tools/ainetwork/transfer.py @@ -1,8 +1,9 @@ import json from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import AsyncCallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.ainetwork.base import AINBaseTool diff --git a/libs/langchain/langchain/tools/ainetwork/value.py b/libs/langchain/langchain/tools/ainetwork/value.py index 844b98e9968..3153a1e419d 100644 --- a/libs/langchain/langchain/tools/ainetwork/value.py +++ b/libs/langchain/langchain/tools/ainetwork/value.py @@ -2,8 +2,9 @@ import builtins import json from typing import Optional, Type, Union +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import AsyncCallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.ainetwork.base import AINBaseTool, OperationType diff --git a/libs/langchain/langchain/tools/amadeus/base.py b/libs/langchain/langchain/tools/amadeus/base.py index c2db135135a..6815bc9f3b1 100644 --- a/libs/langchain/langchain/tools/amadeus/base.py +++ b/libs/langchain/langchain/tools/amadeus/base.py @@ -3,7 +3,8 @@ from __future__ import annotations from typing import TYPE_CHECKING -from langchain.pydantic_v1 import Field +from langchain_core.pydantic_v1 import Field + from langchain.tools.amadeus.utils import authenticate from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/tools/amadeus/closest_airport.py b/libs/langchain/langchain/tools/amadeus/closest_airport.py index ff55239e04d..55e2eac0afb 100644 --- a/libs/langchain/langchain/tools/amadeus/closest_airport.py +++ b/libs/langchain/langchain/tools/amadeus/closest_airport.py @@ -1,9 +1,10 @@ from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun from langchain.chains import LLMChain from langchain.chat_models import ChatOpenAI -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.amadeus.base import AmadeusBaseTool diff --git a/libs/langchain/langchain/tools/amadeus/flight_search.py b/libs/langchain/langchain/tools/amadeus/flight_search.py index 6a82e9bc2e7..603a397a7d6 100644 --- a/libs/langchain/langchain/tools/amadeus/flight_search.py +++ b/libs/langchain/langchain/tools/amadeus/flight_search.py @@ -2,8 +2,9 @@ import logging from datetime import datetime as dt from typing import Dict, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.amadeus.base import AmadeusBaseTool logger = 
logging.getLogger(__name__) diff --git a/libs/langchain/langchain/tools/arxiv/tool.py b/libs/langchain/langchain/tools/arxiv/tool.py index c6b8e98e9dc..c155264b085 100644 --- a/libs/langchain/langchain/tools/arxiv/tool.py +++ b/libs/langchain/langchain/tools/arxiv/tool.py @@ -2,8 +2,9 @@ from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.utilities.arxiv import ArxivAPIWrapper diff --git a/libs/langchain/langchain/tools/azure_cognitive_services/form_recognizer.py b/libs/langchain/langchain/tools/azure_cognitive_services/form_recognizer.py index b5dce896e39..6c00104e897 100644 --- a/libs/langchain/langchain/tools/azure_cognitive_services/form_recognizer.py +++ b/libs/langchain/langchain/tools/azure_cognitive_services/form_recognizer.py @@ -3,8 +3,9 @@ from __future__ import annotations import logging from typing import Any, Dict, List, Optional +from langchain_core.pydantic_v1 import root_validator + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import root_validator from langchain.tools.azure_cognitive_services.utils import detect_file_src_type from langchain.tools.base import BaseTool from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/tools/azure_cognitive_services/image_analysis.py b/libs/langchain/langchain/tools/azure_cognitive_services/image_analysis.py index b2032fcb994..efcab61ea2e 100644 --- a/libs/langchain/langchain/tools/azure_cognitive_services/image_analysis.py +++ b/libs/langchain/langchain/tools/azure_cognitive_services/image_analysis.py @@ -3,8 +3,9 @@ from __future__ import annotations import logging from typing import Any, Dict, Optional +from langchain_core.pydantic_v1 import root_validator + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import root_validator from langchain.tools.azure_cognitive_services.utils import detect_file_src_type from langchain.tools.base import BaseTool from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/tools/azure_cognitive_services/speech2text.py b/libs/langchain/langchain/tools/azure_cognitive_services/speech2text.py index 0f7f82937e7..ec37e4d8974 100644 --- a/libs/langchain/langchain/tools/azure_cognitive_services/speech2text.py +++ b/libs/langchain/langchain/tools/azure_cognitive_services/speech2text.py @@ -4,8 +4,9 @@ import logging import time from typing import Any, Dict, Optional +from langchain_core.pydantic_v1 import root_validator + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import root_validator from langchain.tools.azure_cognitive_services.utils import ( detect_file_src_type, download_audio_from_url, diff --git a/libs/langchain/langchain/tools/azure_cognitive_services/text2speech.py b/libs/langchain/langchain/tools/azure_cognitive_services/text2speech.py index fee35591f38..3ee671dbffc 100644 --- a/libs/langchain/langchain/tools/azure_cognitive_services/text2speech.py +++ b/libs/langchain/langchain/tools/azure_cognitive_services/text2speech.py @@ -4,8 +4,9 @@ import logging import tempfile from typing import Any, Dict, Optional +from langchain_core.pydantic_v1 import root_validator + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import root_validator from 
langchain.tools.base import BaseTool from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 2d3cc7ac10b..ca8323e1825 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -1,845 +1,19 @@ -"""Base implementation for tools or skills.""" -from __future__ import annotations - -import asyncio -import inspect -import warnings -from abc import abstractmethod -from functools import partial -from inspect import signature -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union - -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import ( - AsyncCallbackManager, - AsyncCallbackManagerForToolRun, - CallbackManager, - CallbackManagerForToolRun, - Callbacks, +from langchain_core.tool import ( + BaseTool, + SchemaAnnotationError, + StructuredTool, + Tool, + ToolException, + create_schema_from_function, + tool, ) -from langchain.load.serializable import Serializable -from langchain.pydantic_v1 import ( - BaseModel, - Extra, - Field, - create_model, - root_validator, - validate_arguments, -) -from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable - -class SchemaAnnotationError(TypeError): - """Raised when 'args_schema' is missing or has an incorrect type annotation.""" - - -def _create_subset_model( - name: str, model: BaseModel, field_names: list -) -> Type[BaseModel]: - """Create a pydantic model with only a subset of model's fields.""" - fields = {} - for field_name in field_names: - field = model.__fields__[field_name] - fields[field_name] = (field.outer_type_, field.field_info) - return create_model(name, **fields) # type: ignore - - -def _get_filtered_args( - inferred_model: Type[BaseModel], - func: Callable, -) -> dict: - """Get the arguments from a function's signature.""" - schema = inferred_model.schema()["properties"] - valid_keys = signature(func).parameters - return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} - - -class _SchemaConfig: - """Configuration for the pydantic model.""" - - extra: Any = Extra.forbid - arbitrary_types_allowed: bool = True - - -def create_schema_from_function( - model_name: str, - func: Callable, -) -> Type[BaseModel]: - """Create a pydantic schema from a function's signature. - Args: - model_name: Name to assign to the generated pydandic schema - func: Function to generate the schema from - Returns: - A pydantic model with the same arguments as the function - """ - # https://docs.pydantic.dev/latest/usage/validation_decorator/ - validated = validate_arguments(func, config=_SchemaConfig) # type: ignore - inferred_model = validated.model # type: ignore - if "run_manager" in inferred_model.__fields__: - del inferred_model.__fields__["run_manager"] - if "callbacks" in inferred_model.__fields__: - del inferred_model.__fields__["callbacks"] - # Pydantic adds placeholder virtual fields we need to strip - valid_properties = _get_filtered_args(inferred_model, func) - return _create_subset_model( - f"{model_name}Schema", inferred_model, list(valid_properties) - ) - - -class ToolException(Exception): - """An optional exception that tool throws when execution error occurs. 
- - When this exception is thrown, the agent will not stop working, - but will handle the exception according to the handle_tool_error - variable of the tool, and the processing result will be returned - to the agent as observation, and printed in red on the console. - """ - - pass - - -class BaseTool(RunnableSerializable[Union[str, Dict], Any]): - """Interface LangChain tools must implement.""" - - def __init_subclass__(cls, **kwargs: Any) -> None: - """Create the definition of the new tool class.""" - super().__init_subclass__(**kwargs) - - args_schema_type = cls.__annotations__.get("args_schema", None) - - if args_schema_type is not None: - if args_schema_type is None or args_schema_type == BaseModel: - # Throw errors for common mis-annotations. - # TODO: Use get_args / get_origin and fully - # specify valid annotations. - typehint_mandate = """ -class ChildTool(BaseTool): - ... - args_schema: Type[BaseModel] = SchemaClass - ...""" - name = cls.__name__ - raise SchemaAnnotationError( - f"Tool definition for {name} must include valid type annotations" - f" for argument 'args_schema' to behave as expected.\n" - f"Expected annotation of 'Type[BaseModel]'" - f" but got '{args_schema_type}'.\n" - f"Expected class looks like:\n" - f"{typehint_mandate}" - ) - - name: str - """The unique name of the tool that clearly communicates its purpose.""" - description: str - """Used to tell the model how/when/why to use the tool. - - You can provide few-shot examples as a part of the description. - """ - args_schema: Optional[Type[BaseModel]] = None - """Pydantic model class to validate and parse the tool's input arguments.""" - return_direct: bool = False - """Whether to return the tool's output directly. Setting this to True means - - that after the tool is called, the AgentExecutor will stop looping. - """ - verbose: bool = False - """Whether to log the tool's progress.""" - - callbacks: Callbacks = Field(default=None, exclude=True) - """Callbacks to be called during tool execution.""" - callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) - """Deprecated. Please use callbacks instead.""" - tags: Optional[List[str]] = None - """Optional list of tags associated with the tool. Defaults to None - These tags will be associated with each call to this tool, - and passed as arguments to the handlers defined in `callbacks`. - You can use these to eg identify a specific instance of a tool with its use case. - """ - metadata: Optional[Dict[str, Any]] = None - """Optional metadata associated with the tool. Defaults to None - This metadata will be associated with each call to this tool, - and passed as arguments to the handlers defined in `callbacks`. - You can use these to eg identify a specific instance of a tool with its use case. 
- """ - - handle_tool_error: Optional[ - Union[bool, str, Callable[[ToolException], str]] - ] = False - """Handle the content of the ToolException thrown.""" - - class Config(Serializable.Config): - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @property - def is_single_input(self) -> bool: - """Whether the tool only accepts a single input.""" - keys = {k for k in self.args if k != "kwargs"} - return len(keys) == 1 - - @property - def args(self) -> dict: - if self.args_schema is not None: - return self.args_schema.schema()["properties"] - else: - schema = create_schema_from_function(self.name, self._run) - return schema.schema()["properties"] - - # --- Runnable --- - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - """The tool's input schema.""" - if self.args_schema is not None: - return self.args_schema - else: - return create_schema_from_function(self.name, self._run) - - def invoke( - self, - input: Union[str, Dict], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Any: - config = config or {} - return self.run( - input, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - **kwargs, - ) - - async def ainvoke( - self, - input: Union[str, Dict], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Any: - config = config or {} - return await self.arun( - input, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - **kwargs, - ) - - # --- Tool --- - - def _parse_input( - self, - tool_input: Union[str, Dict], - ) -> Union[str, Dict[str, Any]]: - """Convert tool input to pydantic model.""" - input_args = self.args_schema - if isinstance(tool_input, str): - if input_args is not None: - key_ = next(iter(input_args.__fields__.keys())) - input_args.validate({key_: tool_input}) - return tool_input - else: - if input_args is not None: - result = input_args.parse_obj(tool_input) - return {k: v for k, v in result.dict().items() if k in tool_input} - return tool_input - - @root_validator() - def raise_deprecation(cls, values: Dict) -> Dict: - """Raise deprecation warning if callback_manager is used.""" - if values.get("callback_manager") is not None: - warnings.warn( - "callback_manager is deprecated. Please use callbacks instead.", - DeprecationWarning, - ) - values["callbacks"] = values.pop("callback_manager", None) - return values - - @abstractmethod - def _run( - self, - *args: Any, - **kwargs: Any, - ) -> Any: - """Use the tool. - - Add run_manager: Optional[CallbackManagerForToolRun] = None - to child implementations to enable tracing, - """ - - async def _arun( - self, - *args: Any, - **kwargs: Any, - ) -> Any: - """Use the tool asynchronously. - - Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None - to child implementations to enable tracing, - """ - return await asyncio.get_running_loop().run_in_executor( - None, - partial(self._run, **kwargs), - *args, - ) - - def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: - # For backwards compatibility, if run_input is a string, - # pass as a positional argument. 
- if isinstance(tool_input, str): - return (tool_input,), {} - else: - return (), tool_input - - def run( - self, - tool_input: Union[str, Dict], - verbose: Optional[bool] = None, - start_color: Optional[str] = "green", - color: Optional[str] = "green", - callbacks: Callbacks = None, - *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - run_name: Optional[str] = None, - **kwargs: Any, - ) -> Any: - """Run the tool.""" - parsed_input = self._parse_input(tool_input) - if not self.verbose and verbose is not None: - verbose_ = verbose - else: - verbose_ = self.verbose - callback_manager = CallbackManager.configure( - callbacks, - self.callbacks, - verbose_, - tags, - self.tags, - metadata, - self.metadata, - ) - # TODO: maybe also pass through run_manager is _run supports kwargs - new_arg_supported = signature(self._run).parameters.get("run_manager") - run_manager = callback_manager.on_tool_start( - {"name": self.name, "description": self.description}, - tool_input if isinstance(tool_input, str) else str(tool_input), - color=start_color, - name=run_name, - **kwargs, - ) - try: - tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) - observation = ( - self._run(*tool_args, run_manager=run_manager, **tool_kwargs) - if new_arg_supported - else self._run(*tool_args, **tool_kwargs) - ) - except ToolException as e: - if not self.handle_tool_error: - run_manager.on_tool_error(e) - raise e - elif isinstance(self.handle_tool_error, bool): - if e.args: - observation = e.args[0] - else: - observation = "Tool execution error" - elif isinstance(self.handle_tool_error, str): - observation = self.handle_tool_error - elif callable(self.handle_tool_error): - observation = self.handle_tool_error(e) - else: - raise ValueError( - f"Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. 
Received: {self.handle_tool_error}" - ) - run_manager.on_tool_end( - str(observation), color="red", name=self.name, **kwargs - ) - return observation - except (Exception, KeyboardInterrupt) as e: - run_manager.on_tool_error(e) - raise e - else: - run_manager.on_tool_end( - str(observation), color=color, name=self.name, **kwargs - ) - return observation - - async def arun( - self, - tool_input: Union[str, Dict], - verbose: Optional[bool] = None, - start_color: Optional[str] = "green", - color: Optional[str] = "green", - callbacks: Callbacks = None, - *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - run_name: Optional[str] = None, - **kwargs: Any, - ) -> Any: - """Run the tool asynchronously.""" - parsed_input = self._parse_input(tool_input) - if not self.verbose and verbose is not None: - verbose_ = verbose - else: - verbose_ = self.verbose - callback_manager = AsyncCallbackManager.configure( - callbacks, - self.callbacks, - verbose_, - tags, - self.tags, - metadata, - self.metadata, - ) - new_arg_supported = signature(self._arun).parameters.get("run_manager") - run_manager = await callback_manager.on_tool_start( - {"name": self.name, "description": self.description}, - tool_input if isinstance(tool_input, str) else str(tool_input), - color=start_color, - name=run_name, - **kwargs, - ) - try: - # We then call the tool on the tool input to get an observation - tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) - observation = ( - await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs) - if new_arg_supported - else await self._arun(*tool_args, **tool_kwargs) - ) - except ToolException as e: - if not self.handle_tool_error: - await run_manager.on_tool_error(e) - raise e - elif isinstance(self.handle_tool_error, bool): - if e.args: - observation = e.args[0] - else: - observation = "Tool execution error" - elif isinstance(self.handle_tool_error, str): - observation = self.handle_tool_error - elif callable(self.handle_tool_error): - observation = self.handle_tool_error(e) - else: - raise ValueError( - f"Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. 
Received: {self.handle_tool_error}" - ) - await run_manager.on_tool_end( - str(observation), color="red", name=self.name, **kwargs - ) - return observation - except (Exception, KeyboardInterrupt) as e: - await run_manager.on_tool_error(e) - raise e - else: - await run_manager.on_tool_end( - str(observation), color=color, name=self.name, **kwargs - ) - return observation - - def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: - """Make tool callable.""" - return self.run(tool_input, callbacks=callbacks) - - -class Tool(BaseTool): - """Tool that takes in function or coroutine directly.""" - - description: str = "" - func: Optional[Callable[..., str]] - """The function to run when the tool is called.""" - coroutine: Optional[Callable[..., Awaitable[str]]] = None - """The asynchronous version of the function.""" - - # --- Runnable --- - - async def ainvoke( - self, - input: Union[str, Dict], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Any: - if not self.coroutine: - # If the tool does not implement async, fall back to default implementation - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, **kwargs) - ) - - return await super().ainvoke(input, config, **kwargs) - - # --- Tool --- - - @property - def args(self) -> dict: - """The tool's input arguments.""" - if self.args_schema is not None: - return self.args_schema.schema()["properties"] - # For backwards compatibility, if the function signature is ambiguous, - # assume it takes a single string input. - return {"tool_input": {"type": "string"}} - - def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: - """Convert tool input to pydantic model.""" - args, kwargs = super()._to_args_and_kwargs(tool_input) - # For backwards compatibility. The tool must be run with a single input - all_args = list(args) + list(kwargs.values()) - if len(all_args) != 1: - raise ToolException( - f"Too many arguments to single-input tool {self.name}." 
- f" Args: {all_args}" - ) - return tuple(all_args), {} - - def _run( - self, - *args: Any, - run_manager: Optional[CallbackManagerForToolRun] = None, - **kwargs: Any, - ) -> Any: - """Use the tool.""" - if self.func: - new_argument_supported = signature(self.func).parameters.get("callbacks") - return ( - self.func( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else self.func(*args, **kwargs) - ) - raise NotImplementedError("Tool does not support sync") - - async def _arun( - self, - *args: Any, - run_manager: Optional[AsyncCallbackManagerForToolRun] = None, - **kwargs: Any, - ) -> Any: - """Use the tool asynchronously.""" - if self.coroutine: - new_argument_supported = signature(self.coroutine).parameters.get( - "callbacks" - ) - return ( - await self.coroutine( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else await self.coroutine(*args, **kwargs) - ) - else: - return await asyncio.get_running_loop().run_in_executor( - None, partial(self._run, run_manager=run_manager, **kwargs), *args - ) - - # TODO: this is for backwards compatibility, remove in future - def __init__( - self, name: str, func: Optional[Callable], description: str, **kwargs: Any - ) -> None: - """Initialize tool.""" - super(Tool, self).__init__( - name=name, func=func, description=description, **kwargs - ) - - @classmethod - def from_function( - cls, - func: Optional[Callable], - name: str, # We keep these required to support backwards compatibility - description: str, - return_direct: bool = False, - args_schema: Optional[Type[BaseModel]] = None, - coroutine: Optional[ - Callable[..., Awaitable[Any]] - ] = None, # This is last for compatibility, but should be after func - **kwargs: Any, - ) -> Tool: - """Initialize tool from a function.""" - if func is None and coroutine is None: - raise ValueError("Function and/or coroutine must be provided") - return cls( - name=name, - func=func, - coroutine=coroutine, - description=description, - return_direct=return_direct, - args_schema=args_schema, - **kwargs, - ) - - -class StructuredTool(BaseTool): - """Tool that can operate on any number of inputs.""" - - description: str = "" - args_schema: Type[BaseModel] = Field(..., description="The tool schema.") - """The input arguments' schema.""" - func: Optional[Callable[..., Any]] - """The function to run when the tool is called.""" - coroutine: Optional[Callable[..., Awaitable[Any]]] = None - """The asynchronous version of the function.""" - - # --- Runnable --- - - async def ainvoke( - self, - input: Union[str, Dict], - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> Any: - if not self.coroutine: - # If the tool does not implement async, fall back to default implementation - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, **kwargs) - ) - - return await super().ainvoke(input, config, **kwargs) - - # --- Tool --- - - @property - def args(self) -> dict: - """The tool's input arguments.""" - return self.args_schema.schema()["properties"] - - def _run( - self, - *args: Any, - run_manager: Optional[CallbackManagerForToolRun] = None, - **kwargs: Any, - ) -> Any: - """Use the tool.""" - if self.func: - new_argument_supported = signature(self.func).parameters.get("callbacks") - return ( - self.func( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else 
self.func(*args, **kwargs) - ) - raise NotImplementedError("Tool does not support sync") - - async def _arun( - self, - *args: Any, - run_manager: Optional[AsyncCallbackManagerForToolRun] = None, - **kwargs: Any, - ) -> str: - """Use the tool asynchronously.""" - if self.coroutine: - new_argument_supported = signature(self.coroutine).parameters.get( - "callbacks" - ) - return ( - await self.coroutine( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else await self.coroutine(*args, **kwargs) - ) - return await asyncio.get_running_loop().run_in_executor( - None, - partial(self._run, run_manager=run_manager, **kwargs), - *args, - ) - - @classmethod - def from_function( - cls, - func: Optional[Callable] = None, - coroutine: Optional[Callable[..., Awaitable[Any]]] = None, - name: Optional[str] = None, - description: Optional[str] = None, - return_direct: bool = False, - args_schema: Optional[Type[BaseModel]] = None, - infer_schema: bool = True, - **kwargs: Any, - ) -> StructuredTool: - """Create tool from a given function. - - A classmethod that helps to create a tool from a function. - - Args: - func: The function from which to create a tool - coroutine: The async function from which to create a tool - name: The name of the tool. Defaults to the function name - description: The description of the tool. Defaults to the function docstring - return_direct: Whether to return the result directly or as a callback - args_schema: The schema of the tool's input arguments - infer_schema: Whether to infer the schema from the function's signature - **kwargs: Additional arguments to pass to the tool - - Returns: - The tool - - Examples: - - .. code-block:: python - - def add(a: int, b: int) -> int: - \"\"\"Add two numbers\"\"\" - return a + b - tool = StructuredTool.from_function(add) - tool.run(1, 2) # 3 - """ - - if func is not None: - source_function = func - elif coroutine is not None: - source_function = coroutine - else: - raise ValueError("Function and/or coroutine must be provided") - name = name or source_function.__name__ - description = description or source_function.__doc__ - if description is None: - raise ValueError( - "Function must have a docstring if description not provided." - ) - - # Description example: - # search_api(query: str) - Searches the API for the query. - sig = signature(source_function) - description = f"{name}{sig} - {description.strip()}" - _args_schema = args_schema - if _args_schema is None and infer_schema: - _args_schema = create_schema_from_function(f"{name}Schema", source_function) - return cls( - name=name, - func=func, - coroutine=coroutine, - args_schema=_args_schema, - description=description, - return_direct=return_direct, - **kwargs, - ) - - -def tool( - *args: Union[str, Callable, Runnable], - return_direct: bool = False, - args_schema: Optional[Type[BaseModel]] = None, - infer_schema: bool = True, -) -> Callable: - """Make tools out of functions, can be used with or without arguments. - - Args: - *args: The arguments to the tool. - return_direct: Whether to return directly from the tool rather - than continuing the agent loop. - args_schema: optional argument schema for user to specify - infer_schema: Whether to infer the schema of the arguments from - the function's signature. This also makes the resultant tool - accept a dictionary input to its `run()` function. - - Requires: - - Function must be of type (str) -> str - - Function must have a docstring - - Examples: - .. 
code-block:: python - - @tool - def search_api(query: str) -> str: - # Searches the API for the query. - return - - @tool("search", return_direct=True) - def search_api(query: str) -> str: - # Searches the API for the query. - return - """ - - def _make_with_name(tool_name: str) -> Callable: - def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: - if isinstance(dec_func, Runnable): - runnable = dec_func - - if runnable.input_schema.schema().get("type") != "object": - raise ValueError("Runnable must have an object schema.") - - async def ainvoke_wrapper( - callbacks: Optional[Callbacks] = None, **kwargs: Any - ) -> Any: - return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) - - def invoke_wrapper( - callbacks: Optional[Callbacks] = None, **kwargs: Any - ) -> Any: - return runnable.invoke(kwargs, {"callbacks": callbacks}) - - coroutine = ainvoke_wrapper - func = invoke_wrapper - schema: Optional[Type[BaseModel]] = runnable.input_schema - description = repr(runnable) - elif inspect.iscoroutinefunction(dec_func): - coroutine = dec_func - func = None - schema = args_schema - description = None - else: - coroutine = None - func = dec_func - schema = args_schema - description = None - - if infer_schema or args_schema is not None: - return StructuredTool.from_function( - func, - coroutine, - name=tool_name, - description=description, - return_direct=return_direct, - args_schema=schema, - infer_schema=infer_schema, - ) - # If someone doesn't want a schema applied, we must treat it as - # a simple string->string function - if func.__doc__ is None: - raise ValueError( - "Function must have a docstring if " - "description not provided and infer_schema is False." - ) - return Tool( - name=tool_name, - func=func, - description=f"{tool_name} tool", - return_direct=return_direct, - coroutine=coroutine, - ) - - return _make_tool - - if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): - return _make_with_name(args[0])(args[1]) - elif len(args) == 1 and isinstance(args[0], str): - # if the argument is a string, then we use the string as the tool name - # Example usage: @tool("search", return_direct=True) - return _make_with_name(args[0]) - elif len(args) == 1 and callable(args[0]): - # if the argument is a function, then we use the function name as the tool name - # Example usage: @tool - return _make_with_name(args[0].__name__)(args[0]) - elif len(args) == 0: - # if there are no arguments, then we use the function name as the tool name - # Example usage: @tool(return_direct=True) - def _partial(func: Callable[[str], str]) -> BaseTool: - return _make_with_name(func.__name__)(func) - - return _partial - else: - raise ValueError("Too many arguments for tool decorator") +__all__ = [ + "SchemaAnnotationError", + "create_schema_from_function", + "ToolException", + "BaseTool", + "Tool", + "StructuredTool", + "tool", +] diff --git a/libs/langchain/langchain/tools/bearly/tool.py b/libs/langchain/langchain/tools/bearly/tool.py index 262768261cb..24432c27b13 100644 --- a/libs/langchain/langchain/tools/bearly/tool.py +++ b/libs/langchain/langchain/tools/bearly/tool.py @@ -6,8 +6,8 @@ from pathlib import Path from typing import Dict, List, Type import requests +from langchain_core.pydantic_v1 import BaseModel, Field -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools import Tool diff --git a/libs/langchain/langchain/tools/clickup/tool.py b/libs/langchain/langchain/tools/clickup/tool.py index 056ba21426f..83d601356fc 100644 --- 
a/libs/langchain/langchain/tools/clickup/tool.py +++ b/libs/langchain/langchain/tools/clickup/tool.py @@ -28,8 +28,9 @@ agent = initialize_agent( """ from typing import Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import Field from langchain.tools.base import BaseTool from langchain.utilities.clickup import ClickupAPIWrapper diff --git a/libs/langchain/langchain/tools/dataforseo_api_search/tool.py b/libs/langchain/langchain/tools/dataforseo_api_search/tool.py index 6d6509999dd..c85f6fc3fa2 100644 --- a/libs/langchain/langchain/tools/dataforseo_api_search/tool.py +++ b/libs/langchain/langchain/tools/dataforseo_api_search/tool.py @@ -2,11 +2,12 @@ from typing import Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import Field from langchain.tools.base import BaseTool from langchain.utilities.dataforseo_api_search import DataForSeoAPIWrapper diff --git a/libs/langchain/langchain/tools/ddg_search/tool.py b/libs/langchain/langchain/tools/ddg_search/tool.py index bd2e459ea39..93b5b8d08dc 100644 --- a/libs/langchain/langchain/tools/ddg_search/tool.py +++ b/libs/langchain/langchain/tools/ddg_search/tool.py @@ -3,8 +3,9 @@ import warnings from typing import Any, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper diff --git a/libs/langchain/langchain/tools/e2b_data_analysis/tool.py b/libs/langchain/langchain/tools/e2b_data_analysis/tool.py index fb8d0a789db..8c99e49bc64 100644 --- a/libs/langchain/langchain/tools/e2b_data_analysis/tool.py +++ b/libs/langchain/langchain/tools/e2b_data_analysis/tool.py @@ -7,12 +7,13 @@ from io import StringIO from sys import version_info from typing import IO, TYPE_CHECKING, Any, Callable, List, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManager, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field, PrivateAttr from langchain.tools import BaseTool, Tool from langchain.tools.e2b_data_analysis.unparse import Unparser @@ -98,6 +99,7 @@ class E2BDataAnalysisTool(BaseTool): name = "e2b_data_analysis" args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments session: Any + description: str _uploaded_files: List[UploadedFile] = PrivateAttr(default_factory=list) def __init__( diff --git a/libs/langchain/langchain/tools/edenai/audio_speech_to_text.py b/libs/langchain/langchain/tools/edenai/audio_speech_to_text.py index 28dbce47d02..c5dd68c40c3 100644 --- a/libs/langchain/langchain/tools/edenai/audio_speech_to_text.py +++ b/libs/langchain/langchain/tools/edenai/audio_speech_to_text.py @@ -6,9 +6,9 @@ import time from typing import List, Optional import requests +from langchain_core.pydantic_v1 import validator from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import validator from langchain.tools.edenai.edenai_base_tool import EdenaiTool logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/tools/edenai/audio_text_to_speech.py 
b/libs/langchain/langchain/tools/edenai/audio_text_to_speech.py index 968d03c1a20..e0bbafc4795 100644 --- a/libs/langchain/langchain/tools/edenai/audio_text_to_speech.py +++ b/libs/langchain/langchain/tools/edenai/audio_text_to_speech.py @@ -4,9 +4,9 @@ import logging from typing import Dict, List, Literal, Optional import requests +from langchain_core.pydantic_v1 import Field, root_validator, validator from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import Field, root_validator, validator from langchain.tools.edenai.edenai_base_tool import EdenaiTool logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/tools/edenai/edenai_base_tool.py b/libs/langchain/langchain/tools/edenai/edenai_base_tool.py index e0ebc543a8e..eb4903d03b4 100644 --- a/libs/langchain/langchain/tools/edenai/edenai_base_tool.py +++ b/libs/langchain/langchain/tools/edenai/edenai_base_tool.py @@ -5,9 +5,9 @@ from abc import abstractmethod from typing import Any, Dict, List, Optional import requests +from langchain_core.pydantic_v1 import root_validator from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import root_validator from langchain.tools.base import BaseTool from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/tools/eleven_labs/text2speech.py b/libs/langchain/langchain/tools/eleven_labs/text2speech.py index 170a078a8b2..455ec0afc12 100644 --- a/libs/langchain/langchain/tools/eleven_labs/text2speech.py +++ b/libs/langchain/langchain/tools/eleven_labs/text2speech.py @@ -2,8 +2,9 @@ import tempfile from enum import Enum from typing import Any, Dict, Optional, Union +from langchain_core.pydantic_v1 import root_validator + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import root_validator from langchain.tools.base import BaseTool from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/tools/file_management/copy.py b/libs/langchain/langchain/tools/file_management/copy.py index 48fabe3be50..8ae417c5a4f 100644 --- a/libs/langchain/langchain/tools/file_management/copy.py +++ b/libs/langchain/langchain/tools/file_management/copy.py @@ -1,8 +1,9 @@ import shutil from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, diff --git a/libs/langchain/langchain/tools/file_management/delete.py b/libs/langchain/langchain/tools/file_management/delete.py index 07be6e9e10f..eb01899f264 100644 --- a/libs/langchain/langchain/tools/file_management/delete.py +++ b/libs/langchain/langchain/tools/file_management/delete.py @@ -1,8 +1,9 @@ import os from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, diff --git a/libs/langchain/langchain/tools/file_management/file_search.py b/libs/langchain/langchain/tools/file_management/file_search.py index aabdb389534..5bfa86bc329 100644 --- a/libs/langchain/langchain/tools/file_management/file_search.py +++ 
b/libs/langchain/langchain/tools/file_management/file_search.py @@ -2,8 +2,9 @@ import fnmatch import os from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, diff --git a/libs/langchain/langchain/tools/file_management/list_dir.py b/libs/langchain/langchain/tools/file_management/list_dir.py index 5399be28dd0..a5868cbc4ed 100644 --- a/libs/langchain/langchain/tools/file_management/list_dir.py +++ b/libs/langchain/langchain/tools/file_management/list_dir.py @@ -1,8 +1,9 @@ import os from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, diff --git a/libs/langchain/langchain/tools/file_management/move.py b/libs/langchain/langchain/tools/file_management/move.py index 4d348821a86..cd3c4d8873c 100644 --- a/libs/langchain/langchain/tools/file_management/move.py +++ b/libs/langchain/langchain/tools/file_management/move.py @@ -1,8 +1,9 @@ import shutil from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, diff --git a/libs/langchain/langchain/tools/file_management/read.py b/libs/langchain/langchain/tools/file_management/read.py index 1d1a6edf4d8..180d03a72f5 100644 --- a/libs/langchain/langchain/tools/file_management/read.py +++ b/libs/langchain/langchain/tools/file_management/read.py @@ -1,7 +1,8 @@ from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, diff --git a/libs/langchain/langchain/tools/file_management/utils.py b/libs/langchain/langchain/tools/file_management/utils.py index 21b0807e51a..b2a3632ecaa 100644 --- a/libs/langchain/langchain/tools/file_management/utils.py +++ b/libs/langchain/langchain/tools/file_management/utils.py @@ -2,7 +2,7 @@ import sys from pathlib import Path from typing import Optional -from langchain.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel def is_relative_to(path: Path, root: Path) -> bool: diff --git a/libs/langchain/langchain/tools/file_management/write.py b/libs/langchain/langchain/tools/file_management/write.py index 09bac8a3547..278030cd163 100644 --- a/libs/langchain/langchain/tools/file_management/write.py +++ b/libs/langchain/langchain/tools/file_management/write.py @@ -1,7 +1,8 @@ from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, diff --git 
a/libs/langchain/langchain/tools/github/tool.py b/libs/langchain/langchain/tools/github/tool.py index ec67fd2b335..6099ea47d8b 100644 --- a/libs/langchain/langchain/tools/github/tool.py +++ b/libs/langchain/langchain/tools/github/tool.py @@ -9,8 +9,9 @@ To use this tool, you must first set as environment variables: """ from typing import Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import Field from langchain.tools.base import BaseTool from langchain.utilities.github import GitHubAPIWrapper diff --git a/libs/langchain/langchain/tools/gitlab/tool.py b/libs/langchain/langchain/tools/gitlab/tool.py index fc8105c50af..f3f13a30d6c 100644 --- a/libs/langchain/langchain/tools/gitlab/tool.py +++ b/libs/langchain/langchain/tools/gitlab/tool.py @@ -9,8 +9,9 @@ To use this tool, you must first set as environment variables: """ from typing import Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import Field from langchain.tools.base import BaseTool from langchain.utilities.gitlab import GitLabAPIWrapper diff --git a/libs/langchain/langchain/tools/gmail/base.py b/libs/langchain/langchain/tools/gmail/base.py index 8ce5e8e85b0..07fcc08da81 100644 --- a/libs/langchain/langchain/tools/gmail/base.py +++ b/libs/langchain/langchain/tools/gmail/base.py @@ -3,7 +3,8 @@ from __future__ import annotations from typing import TYPE_CHECKING -from langchain.pydantic_v1 import Field +from langchain_core.pydantic_v1 import Field + from langchain.tools.base import BaseTool from langchain.tools.gmail.utils import build_resource_service diff --git a/libs/langchain/langchain/tools/gmail/create_draft.py b/libs/langchain/langchain/tools/gmail/create_draft.py index 10d0de57843..7e6cfbfd38f 100644 --- a/libs/langchain/langchain/tools/gmail/create_draft.py +++ b/libs/langchain/langchain/tools/gmail/create_draft.py @@ -2,8 +2,9 @@ import base64 from email.message import EmailMessage from typing import List, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.gmail.base import GmailBaseTool diff --git a/libs/langchain/langchain/tools/gmail/get_message.py b/libs/langchain/langchain/tools/gmail/get_message.py index a5f35ba71e9..94c26bdad15 100644 --- a/libs/langchain/langchain/tools/gmail/get_message.py +++ b/libs/langchain/langchain/tools/gmail/get_message.py @@ -2,8 +2,9 @@ import base64 import email from typing import Dict, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.gmail.base import GmailBaseTool from langchain.tools.gmail.utils import clean_email_body diff --git a/libs/langchain/langchain/tools/gmail/get_thread.py b/libs/langchain/langchain/tools/gmail/get_thread.py index 61754a43a81..2221a20e024 100644 --- a/libs/langchain/langchain/tools/gmail/get_thread.py +++ b/libs/langchain/langchain/tools/gmail/get_thread.py @@ -1,7 +1,8 @@ from typing import Dict, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.gmail.base import GmailBaseTool 
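The Gmail hunks above, and the tool hunks that follow, all make the same two-line change: the `pydantic_v1` import moves from `langchain` to `langchain_core`, while everything still imported from `langchain.*` itself (callback managers, the Gmail base tool, utilities) is left untouched. Combined with the re-export shim that `langchain/tools/base.py` becomes earlier in this patch, downstream tool code can adopt the new paths without other changes. A minimal sketch under those assumptions, using the `BaseTool` interface shown in the removed `tools/base.py` code above; the `GreetingInput`/`GreetingTool` names are illustrative only and not part of this patch:

.. code-block:: python

    from typing import Optional, Type

    # New canonical locations introduced by this PR.
    from langchain_core.pydantic_v1 import BaseModel, Field
    from langchain_core.tool import BaseTool

    # Still imported from the langchain package, exactly as in the hunks above.
    from langchain.callbacks.manager import CallbackManagerForToolRun


    class GreetingInput(BaseModel):
        """Hypothetical args schema, used only to illustrate the new import paths."""

        name: str = Field(description="Who to greet")


    class GreetingTool(BaseTool):
        """Hypothetical tool, not part of this patch."""

        name: str = "greeting"
        description: str = "Return a short greeting for the given name."
        args_schema: Type[BaseModel] = GreetingInput

        def _run(
            self, name: str, run_manager: Optional[CallbackManagerForToolRun] = None
        ) -> str:
            return f"Hello, {name}!"


    # e.g. GreetingTool().run({"name": "World"}) -> "Hello, World!"

Because the shim in `langchain/tools/base.py` re-exports `BaseTool` (see its new `__all__` above), `from langchain.tools.base import BaseTool` should continue to resolve to the same `langchain_core.tool.BaseTool`, so the old import path is not expected to break.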
diff --git a/libs/langchain/langchain/tools/gmail/search.py b/libs/langchain/langchain/tools/gmail/search.py index f4de000f519..d5c55ad9918 100644 --- a/libs/langchain/langchain/tools/gmail/search.py +++ b/libs/langchain/langchain/tools/gmail/search.py @@ -3,8 +3,9 @@ import email from enum import Enum from typing import Any, Dict, List, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.gmail.base import GmailBaseTool from langchain.tools.gmail.utils import clean_email_body diff --git a/libs/langchain/langchain/tools/gmail/send_message.py b/libs/langchain/langchain/tools/gmail/send_message.py index 7121a83d3c8..48bf2fd0dbb 100644 --- a/libs/langchain/langchain/tools/gmail/send_message.py +++ b/libs/langchain/langchain/tools/gmail/send_message.py @@ -4,8 +4,9 @@ from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import Any, Dict, List, Optional, Union +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.gmail.base import GmailBaseTool diff --git a/libs/langchain/langchain/tools/google_places/tool.py b/libs/langchain/langchain/tools/google_places/tool.py index c34c5198762..5cdd09bd244 100644 --- a/libs/langchain/langchain/tools/google_places/tool.py +++ b/libs/langchain/langchain/tools/google_places/tool.py @@ -2,8 +2,9 @@ from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.utilities.google_places_api import GooglePlacesAPIWrapper diff --git a/libs/langchain/langchain/tools/google_serper/tool.py b/libs/langchain/langchain/tools/google_serper/tool.py index 23703b5bc40..a826539ceab 100644 --- a/libs/langchain/langchain/tools/google_serper/tool.py +++ b/libs/langchain/langchain/tools/google_serper/tool.py @@ -2,11 +2,12 @@ from typing import Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import Field from langchain.tools.base import BaseTool from langchain.utilities.google_serper import GoogleSerperAPIWrapper diff --git a/libs/langchain/langchain/tools/human/tool.py b/libs/langchain/langchain/tools/human/tool.py index 30591c004de..d8549291d04 100644 --- a/libs/langchain/langchain/tools/human/tool.py +++ b/libs/langchain/langchain/tools/human/tool.py @@ -2,8 +2,9 @@ from typing import Callable, Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import Field from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/tools/jira/tool.py b/libs/langchain/langchain/tools/jira/tool.py index 533b701b8c3..f582a307a9c 100644 --- a/libs/langchain/langchain/tools/jira/tool.py +++ b/libs/langchain/langchain/tools/jira/tool.py @@ -30,8 +30,9 @@ agent = initialize_agent( """ from typing import Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import Field from langchain.tools.base import 
BaseTool from langchain.utilities.jira import JiraAPIWrapper diff --git a/libs/langchain/langchain/tools/json/tool.py b/libs/langchain/langchain/tools/json/tool.py index 9f82ce91ff4..87bd05feb92 100644 --- a/libs/langchain/langchain/tools/json/tool.py +++ b/libs/langchain/langchain/tools/json/tool.py @@ -7,7 +7,7 @@ import re from pathlib import Path from typing import Dict, List, Optional, Union -from langchain.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, diff --git a/libs/langchain/langchain/tools/memorize/tool.py b/libs/langchain/langchain/tools/memorize/tool.py index 2e9dafa0b97..d7deea7a848 100644 --- a/libs/langchain/langchain/tools/memorize/tool.py +++ b/libs/langchain/langchain/tools/memorize/tool.py @@ -1,12 +1,13 @@ from abc import abstractmethod from typing import Any, Optional, Protocol, Sequence, runtime_checkable +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) from langchain.llms.gradient_ai import TrainResult -from langchain.pydantic_v1 import Field from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/tools/multion/close_session.py b/libs/langchain/langchain/tools/multion/close_session.py index c78557e56cc..c720ab00a40 100644 --- a/libs/langchain/langchain/tools/multion/close_session.py +++ b/libs/langchain/langchain/tools/multion/close_session.py @@ -1,11 +1,12 @@ import asyncio from typing import TYPE_CHECKING, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool if TYPE_CHECKING: diff --git a/libs/langchain/langchain/tools/multion/create_session.py b/libs/langchain/langchain/tools/multion/create_session.py index 008d5be0ae3..6630997d0b3 100644 --- a/libs/langchain/langchain/tools/multion/create_session.py +++ b/libs/langchain/langchain/tools/multion/create_session.py @@ -1,11 +1,12 @@ import asyncio from typing import TYPE_CHECKING, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool if TYPE_CHECKING: diff --git a/libs/langchain/langchain/tools/multion/update_session.py b/libs/langchain/langchain/tools/multion/update_session.py index 5710de95316..c41089b30ba 100644 --- a/libs/langchain/langchain/tools/multion/update_session.py +++ b/libs/langchain/langchain/tools/multion/update_session.py @@ -1,11 +1,12 @@ import asyncio from typing import TYPE_CHECKING, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool if TYPE_CHECKING: diff --git a/libs/langchain/langchain/tools/nuclia/tool.py b/libs/langchain/langchain/tools/nuclia/tool.py index e4aba0fa72c..1e14f62c9fa 100644 --- a/libs/langchain/langchain/tools/nuclia/tool.py +++ b/libs/langchain/langchain/tools/nuclia/tool.py @@ -16,12 +16,12 @@ import os from typing import Any, Dict, Optional, Type, Union import requests +from langchain_core.pydantic_v1 
import BaseModel, Field from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/tools/office365/base.py b/libs/langchain/langchain/tools/office365/base.py index 35407c663d1..26ba5d9e57f 100644 --- a/libs/langchain/langchain/tools/office365/base.py +++ b/libs/langchain/langchain/tools/office365/base.py @@ -3,7 +3,8 @@ from __future__ import annotations from typing import TYPE_CHECKING -from langchain.pydantic_v1 import Field +from langchain_core.pydantic_v1 import Field + from langchain.tools.base import BaseTool from langchain.tools.office365.utils import authenticate diff --git a/libs/langchain/langchain/tools/office365/create_draft_message.py b/libs/langchain/langchain/tools/office365/create_draft_message.py index c3d69bf9547..6585b048225 100644 --- a/libs/langchain/langchain/tools/office365/create_draft_message.py +++ b/libs/langchain/langchain/tools/office365/create_draft_message.py @@ -1,7 +1,8 @@ from typing import List, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.office365.base import O365BaseTool diff --git a/libs/langchain/langchain/tools/office365/events_search.py b/libs/langchain/langchain/tools/office365/events_search.py index 7e1b1bd1306..0438b911624 100644 --- a/libs/langchain/langchain/tools/office365/events_search.py +++ b/libs/langchain/langchain/tools/office365/events_search.py @@ -7,8 +7,9 @@ https://learn.microsoft.com/en-us/graph/auth/ from datetime import datetime as dt from typing import Any, Dict, List, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Extra, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Extra, Field from langchain.tools.office365.base import O365BaseTool from langchain.tools.office365.utils import clean_body diff --git a/libs/langchain/langchain/tools/office365/messages_search.py b/libs/langchain/langchain/tools/office365/messages_search.py index eca8ed33f6f..66cdb9dbb03 100644 --- a/libs/langchain/langchain/tools/office365/messages_search.py +++ b/libs/langchain/langchain/tools/office365/messages_search.py @@ -6,8 +6,9 @@ https://learn.microsoft.com/en-us/graph/auth/ from typing import Any, Dict, List, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Extra, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Extra, Field from langchain.tools.office365.base import O365BaseTool from langchain.tools.office365.utils import clean_body diff --git a/libs/langchain/langchain/tools/office365/send_event.py b/libs/langchain/langchain/tools/office365/send_event.py index 6151dde1439..29565311949 100644 --- a/libs/langchain/langchain/tools/office365/send_event.py +++ b/libs/langchain/langchain/tools/office365/send_event.py @@ -7,8 +7,9 @@ https://learn.microsoft.com/en-us/graph/auth/ from datetime import datetime as dt from typing import List, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.office365.base import O365BaseTool diff --git 
a/libs/langchain/langchain/tools/office365/send_message.py b/libs/langchain/langchain/tools/office365/send_message.py index 3590ba406fa..b00c9853073 100644 --- a/libs/langchain/langchain/tools/office365/send_message.py +++ b/libs/langchain/langchain/tools/office365/send_message.py @@ -1,7 +1,8 @@ from typing import List, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.office365.base import O365BaseTool diff --git a/libs/langchain/langchain/tools/openapi/utils/api_models.py b/libs/langchain/langchain/tools/openapi/utils/api_models.py index 0dc5147aeef..a3287b76ba9 100644 --- a/libs/langchain/langchain/tools/openapi/utils/api_models.py +++ b/libs/langchain/langchain/tools/openapi/utils/api_models.py @@ -15,7 +15,8 @@ from typing import ( Union, ) -from langchain.pydantic_v1 import BaseModel, Field +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/tools/openweathermap/tool.py b/libs/langchain/langchain/tools/openweathermap/tool.py index cf365e8ca8e..7f1bf1e4831 100644 --- a/libs/langchain/langchain/tools/openweathermap/tool.py +++ b/libs/langchain/langchain/tools/openweathermap/tool.py @@ -2,8 +2,9 @@ from typing import Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import Field from langchain.tools.base import BaseTool from langchain.utilities.openweathermap import OpenWeatherMapAPIWrapper diff --git a/libs/langchain/langchain/tools/playwright/base.py b/libs/langchain/langchain/tools/playwright/base.py index bf9209fff8a..b9ddf8bed04 100644 --- a/libs/langchain/langchain/tools/playwright/base.py +++ b/libs/langchain/langchain/tools/playwright/base.py @@ -2,7 +2,8 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional, Tuple, Type -from langchain.pydantic_v1 import root_validator +from langchain_core.pydantic_v1 import root_validator + from langchain.tools.base import BaseTool if TYPE_CHECKING: diff --git a/libs/langchain/langchain/tools/playwright/click.py b/libs/langchain/langchain/tools/playwright/click.py index b93b69ca126..3a2c4b09c82 100644 --- a/libs/langchain/langchain/tools/playwright/click.py +++ b/libs/langchain/langchain/tools/playwright/click.py @@ -2,11 +2,12 @@ from __future__ import annotations from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( aget_current_page, diff --git a/libs/langchain/langchain/tools/playwright/current_page.py b/libs/langchain/langchain/tools/playwright/current_page.py index b26383f3445..1a66952d53b 100644 --- a/libs/langchain/langchain/tools/playwright/current_page.py +++ b/libs/langchain/langchain/tools/playwright/current_page.py @@ -2,11 +2,12 @@ from __future__ import annotations from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import 
BaseModel from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import aget_current_page, get_current_page diff --git a/libs/langchain/langchain/tools/playwright/extract_hyperlinks.py b/libs/langchain/langchain/tools/playwright/extract_hyperlinks.py index c1c4292c2f2..347a142f3c7 100644 --- a/libs/langchain/langchain/tools/playwright/extract_hyperlinks.py +++ b/libs/langchain/langchain/tools/playwright/extract_hyperlinks.py @@ -3,11 +3,12 @@ from __future__ import annotations import json from typing import TYPE_CHECKING, Any, Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import aget_current_page, get_current_page diff --git a/libs/langchain/langchain/tools/playwright/extract_text.py b/libs/langchain/langchain/tools/playwright/extract_text.py index 86d7a4b556f..6c7d8d0304e 100644 --- a/libs/langchain/langchain/tools/playwright/extract_text.py +++ b/libs/langchain/langchain/tools/playwright/extract_text.py @@ -2,11 +2,12 @@ from __future__ import annotations from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, root_validator + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, root_validator from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import aget_current_page, get_current_page diff --git a/libs/langchain/langchain/tools/playwright/get_elements.py b/libs/langchain/langchain/tools/playwright/get_elements.py index f5098c39d66..4a9c6b436d7 100644 --- a/libs/langchain/langchain/tools/playwright/get_elements.py +++ b/libs/langchain/langchain/tools/playwright/get_elements.py @@ -3,11 +3,12 @@ from __future__ import annotations import json from typing import TYPE_CHECKING, List, Optional, Sequence, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import aget_current_page, get_current_page diff --git a/libs/langchain/langchain/tools/playwright/navigate.py b/libs/langchain/langchain/tools/playwright/navigate.py index 288efe0a8e0..f4d4411ab27 100644 --- a/libs/langchain/langchain/tools/playwright/navigate.py +++ b/libs/langchain/langchain/tools/playwright/navigate.py @@ -3,11 +3,12 @@ from __future__ import annotations from typing import Optional, Type from urllib.parse import urlparse +from langchain_core.pydantic_v1 import BaseModel, Field, validator + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field, validator from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( aget_current_page, diff --git a/libs/langchain/langchain/tools/playwright/navigate_back.py b/libs/langchain/langchain/tools/playwright/navigate_back.py index 41f9ee0323c..97f09ee4993 100644 --- a/libs/langchain/langchain/tools/playwright/navigate_back.py +++ 
b/libs/langchain/langchain/tools/playwright/navigate_back.py @@ -2,11 +2,12 @@ from __future__ import annotations from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( aget_current_page, diff --git a/libs/langchain/langchain/tools/plugin.py b/libs/langchain/langchain/tools/plugin.py index a9b8e19c1fc..ca20bd9a63e 100644 --- a/libs/langchain/langchain/tools/plugin.py +++ b/libs/langchain/langchain/tools/plugin.py @@ -5,12 +5,12 @@ from typing import Optional, Type import requests import yaml +from langchain_core.pydantic_v1 import BaseModel from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/tools/powerbi/tool.py b/libs/langchain/langchain/tools/powerbi/tool.py index c52f695c214..ba496bd800b 100644 --- a/libs/langchain/langchain/tools/powerbi/tool.py +++ b/libs/langchain/langchain/tools/powerbi/tool.py @@ -3,13 +3,14 @@ import logging from time import perf_counter from typing import Any, Dict, Optional, Tuple +from langchain_core.pydantic_v1 import Field, validator + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) from langchain.chains.llm import LLMChain from langchain.chat_models.openai import _import_tiktoken -from langchain.pydantic_v1 import Field, validator from langchain.tools.base import BaseTool from langchain.tools.powerbi.prompt import ( BAD_REQUEST_RESPONSE, diff --git a/libs/langchain/langchain/tools/pubmed/tool.py b/libs/langchain/langchain/tools/pubmed/tool.py index 84bfa72db64..a3d0f619cf6 100644 --- a/libs/langchain/langchain/tools/pubmed/tool.py +++ b/libs/langchain/langchain/tools/pubmed/tool.py @@ -1,7 +1,8 @@ from typing import Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import Field from langchain.tools.base import BaseTool from langchain.utilities.pubmed import PubMedAPIWrapper diff --git a/libs/langchain/langchain/tools/requests/tool.py b/libs/langchain/langchain/tools/requests/tool.py index b47cd98da53..1e47ceb0650 100644 --- a/libs/langchain/langchain/tools/requests/tool.py +++ b/libs/langchain/langchain/tools/requests/tool.py @@ -3,7 +3,7 @@ import json from typing import Any, Dict, Optional -from langchain.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, diff --git a/libs/langchain/langchain/tools/retriever.py b/libs/langchain/langchain/tools/retriever.py index d11c0c1d5c4..96517cbb43a 100644 --- a/libs/langchain/langchain/tools/retriever.py +++ b/libs/langchain/langchain/tools/retriever.py @@ -1,5 +1,6 @@ -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema import BaseRetriever +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import BaseRetriever + from langchain.tools import Tool diff --git a/libs/langchain/langchain/tools/scenexplain/tool.py b/libs/langchain/langchain/tools/scenexplain/tool.py index 59d03125b4f..487e8534304 100644 --- 
a/libs/langchain/langchain/tools/scenexplain/tool.py +++ b/libs/langchain/langchain/tools/scenexplain/tool.py @@ -1,8 +1,9 @@ """Tool for the SceneXplain API.""" from typing import Optional +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.utilities.scenexplain import SceneXplainAPIWrapper diff --git a/libs/langchain/langchain/tools/searchapi/tool.py b/libs/langchain/langchain/tools/searchapi/tool.py index a0a8ed6351e..ee3b9df5f36 100644 --- a/libs/langchain/langchain/tools/searchapi/tool.py +++ b/libs/langchain/langchain/tools/searchapi/tool.py @@ -2,11 +2,12 @@ from typing import Optional +from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import Field from langchain.tools.base import BaseTool from langchain.utilities.searchapi import SearchApiAPIWrapper diff --git a/libs/langchain/langchain/tools/searx_search/tool.py b/libs/langchain/langchain/tools/searx_search/tool.py index dfe99306d3a..793a60ff1aa 100644 --- a/libs/langchain/langchain/tools/searx_search/tool.py +++ b/libs/langchain/langchain/tools/searx_search/tool.py @@ -1,12 +1,14 @@ """Tool for the SearxNG search API.""" from typing import Optional +from langchain_core.pydantic_v1 import Extra + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import Extra -from langchain.tools.base import BaseTool, Field +from langchain.pydantic_v1 import Field +from langchain.tools.base import BaseTool from langchain.utilities.searx_search import SearxSearchWrapper diff --git a/libs/langchain/langchain/tools/shell/tool.py b/libs/langchain/langchain/tools/shell/tool.py index 89522a66d87..7def709e3b2 100644 --- a/libs/langchain/langchain/tools/shell/tool.py +++ b/libs/langchain/langchain/tools/shell/tool.py @@ -3,11 +3,12 @@ import platform import warnings from typing import Any, List, Optional, Type, Union +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/tools/sleep/tool.py b/libs/langchain/langchain/tools/sleep/tool.py index 91906c9f088..ce2205ef732 100644 --- a/libs/langchain/langchain/tools/sleep/tool.py +++ b/libs/langchain/langchain/tools/sleep/tool.py @@ -3,11 +3,12 @@ from asyncio import sleep as asleep from time import sleep from typing import Optional, Type +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/tools/spark_sql/tool.py b/libs/langchain/langchain/tools/spark_sql/tool.py index ad1c6f5ba15..ccb15e1f440 100644 --- a/libs/langchain/langchain/tools/spark_sql/tool.py +++ b/libs/langchain/langchain/tools/spark_sql/tool.py @@ -2,15 +2,15 @@ """Tools for interacting with Spark SQL.""" from typing import Any, Dict, Optional -from langchain.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.pydantic_v1 import BaseModel, 
Field, root_validator -from langchain.schema.language_model import BaseLanguageModel +from langchain_core.schema.language_model import BaseLanguageModel from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) from langchain.chains.llm import LLMChain -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate from langchain.utilities.spark_sql import SparkSQL from langchain.tools.base import BaseTool from langchain.tools.spark_sql.prompt import QUERY_CHECKER diff --git a/libs/langchain/langchain/tools/sql_database/tool.py b/libs/langchain/langchain/tools/sql_database/tool.py index 5dfe8f680fd..8606c28ff24 100644 --- a/libs/langchain/langchain/tools/sql_database/tool.py +++ b/libs/langchain/langchain/tools/sql_database/tool.py @@ -2,15 +2,15 @@ """Tools for interacting with a SQL database.""" from typing import Any, Dict, Optional -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.schema.language_model import BaseLanguageModel +from langchain_core.schema.language_model import BaseLanguageModel from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) from langchain.chains.llm import LLMChain -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate from langchain.utilities.sql_database import SQLDatabase from langchain.tools.base import BaseTool from langchain.tools.sql_database.prompt import QUERY_CHECKER diff --git a/libs/langchain/langchain/tools/steamship_image_generation/tool.py b/libs/langchain/langchain/tools/steamship_image_generation/tool.py index abd38c8d29b..0f25d2ee2a3 100644 --- a/libs/langchain/langchain/tools/steamship_image_generation/tool.py +++ b/libs/langchain/langchain/tools/steamship_image_generation/tool.py @@ -16,8 +16,9 @@ from __future__ import annotations from enum import Enum from typing import TYPE_CHECKING, Dict, Optional +from langchain_core.pydantic_v1 import root_validator + from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import root_validator from langchain.tools import BaseTool from langchain.tools.steamship_image_generation.utils import make_image_public from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/tools/tavily_search/tool.py b/libs/langchain/langchain/tools/tavily_search/tool.py index a054cf42767..c7383e3e4a1 100644 --- a/libs/langchain/langchain/tools/tavily_search/tool.py +++ b/libs/langchain/langchain/tools/tavily_search/tool.py @@ -2,11 +2,12 @@ from typing import Dict, List, Optional, Type, Union +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.utilities.tavily_search import TavilySearchAPIWrapper diff --git a/libs/langchain/langchain/tools/vectorstore/tool.py b/libs/langchain/langchain/tools/vectorstore/tool.py index 1504bdce834..e0192243dd7 100644 --- a/libs/langchain/langchain/tools/vectorstore/tool.py +++ b/libs/langchain/langchain/tools/vectorstore/tool.py @@ -3,11 +3,12 @@ import json from typing import Any, Dict, Optional +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema.language_model import BaseLanguageModel +from 
langchain_core.schema.vectorstore import VectorStore + from langchain.callbacks.manager import CallbackManagerForToolRun from langchain.llms.openai import OpenAI -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.vectorstore import VectorStore from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/tools/yahoo_finance_news.py b/libs/langchain/langchain/tools/yahoo_finance_news.py index 94acba47cff..ee63f416d5a 100644 --- a/libs/langchain/langchain/tools/yahoo_finance_news.py +++ b/libs/langchain/langchain/tools/yahoo_finance_news.py @@ -1,11 +1,11 @@ from typing import Iterable, Optional +from langchain_core.schema import Document from requests.exceptions import HTTPError, ReadTimeout from urllib3.exceptions import ConnectionError from langchain.callbacks.manager import CallbackManagerForToolRun from langchain.document_loaders.web_base import WebBaseLoader -from langchain.schema import Document from langchain.tools.base import BaseTool diff --git a/libs/langchain/langchain/tools/zapier/tool.py b/libs/langchain/langchain/tools/zapier/tool.py index 9f9011756d0..b9b9e5b6c27 100644 --- a/libs/langchain/langchain/tools/zapier/tool.py +++ b/libs/langchain/langchain/tools/zapier/tool.py @@ -81,12 +81,13 @@ agent.run(("Summarize the last email I received regarding Silicon Valley Bank. " """ from typing import Any, Dict, Optional -from langchain._api import warn_deprecated +from langchain_core._api import warn_deprecated +from langchain_core.pydantic_v1 import Field, root_validator + from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain.pydantic_v1 import Field, root_validator from langchain.tools.base import BaseTool from langchain.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT from langchain.utilities.zapier import ZapierNLAWrapper diff --git a/libs/langchain/langchain/utilities/alpha_vantage.py b/libs/langchain/langchain/utilities/alpha_vantage.py index 4a1c6381a7c..31f8f12a465 100644 --- a/libs/langchain/langchain/utilities/alpha_vantage.py +++ b/libs/langchain/langchain/utilities/alpha_vantage.py @@ -2,8 +2,8 @@ from typing import Any, Dict, List, Optional import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.pydantic_v1 import BaseModel, Extra, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/apify.py b/libs/langchain/langchain/utilities/apify.py index b3a835457a0..7a2d04680f8 100644 --- a/libs/langchain/langchain/utilities/apify.py +++ b/libs/langchain/langchain/utilities/apify.py @@ -1,7 +1,8 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema.document import Document +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema.document import Document + from langchain.utils import get_from_dict_or_env if TYPE_CHECKING: diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index 4927acf832b..844f2c15910 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -6,9 +6,8 @@ from enum import Enum from typing import Any, Dict, List, Literal, Mapping, Optional, Union import requests - -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema.retriever 
import Document +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema.retriever import Document class ArceeRoute(str, Enum): diff --git a/libs/langchain/langchain/utilities/arxiv.py b/libs/langchain/langchain/utilities/arxiv.py index 9eef84ecb17..f74d3d4f9ee 100644 --- a/libs/langchain/langchain/utilities/arxiv.py +++ b/libs/langchain/langchain/utilities/arxiv.py @@ -4,8 +4,8 @@ import os import re from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema import Document +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema import Document logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/utilities/awslambda.py b/libs/langchain/langchain/utilities/awslambda.py index f75cf3eec28..1b497dd5dd2 100644 --- a/libs/langchain/langchain/utilities/awslambda.py +++ b/libs/langchain/langchain/utilities/awslambda.py @@ -2,7 +2,7 @@ import json from typing import Any, Dict, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator class LambdaWrapper(BaseModel): diff --git a/libs/langchain/langchain/utilities/bibtex.py b/libs/langchain/langchain/utilities/bibtex.py index 8dcbda88518..45d83aefea3 100644 --- a/libs/langchain/langchain/utilities/bibtex.py +++ b/libs/langchain/langchain/utilities/bibtex.py @@ -2,7 +2,7 @@ import logging from typing import Any, Dict, List, Mapping -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/utilities/bing_search.py b/libs/langchain/langchain/utilities/bing_search.py index 5e4c557185b..2783118a5c4 100644 --- a/libs/langchain/langchain/utilities/bing_search.py +++ b/libs/langchain/langchain/utilities/bing_search.py @@ -6,8 +6,8 @@ https://levelup.gitconnected.com/api-tutorial-how-to-use-bing-web-search-api-in- from typing import Dict, List import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/brave_search.py b/libs/langchain/langchain/utilities/brave_search.py index b3d8b40fe76..91ab8e8bbbc 100644 --- a/libs/langchain/langchain/utilities/brave_search.py +++ b/libs/langchain/langchain/utilities/brave_search.py @@ -2,9 +2,8 @@ import json from typing import List import requests - -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema import Document +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import Document class BraveSearchWrapper(BaseModel): diff --git a/libs/langchain/langchain/utilities/clickup.py b/libs/langchain/langchain/utilities/clickup.py index 052dd9bdd65..375d47693c3 100644 --- a/libs/langchain/langchain/utilities/clickup.py +++ b/libs/langchain/langchain/utilities/clickup.py @@ -5,8 +5,8 @@ from dataclasses import asdict, dataclass, fields from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.pydantic_v1 import BaseModel, Extra, root_validator from langchain.utils import get_from_dict_or_env DEFAULT_URL = "https://api.clickup.com/api/v2" 
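The utilities hunks follow the same split: pydantic primitives, Extra, and Document move to langchain_core, while helpers such as get_from_dict_or_env remain under langchain.utils. The sketch below shows a wrapper written against the new paths; ExampleAPIWrapper and EXAMPLE_API_KEY are made-up names for illustration, not part of this change.

from typing import Dict, List

from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.schema import Document

from langchain.utils import get_from_dict_or_env


class ExampleAPIWrapper(BaseModel):
    """Hypothetical API wrapper using the post-migration import paths."""

    api_key: str = ""

    class Config:
        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        # Pull the key from passed values or the environment, as the wrappers above do.
        values["api_key"] = get_from_dict_or_env(values, "api_key", "EXAMPLE_API_KEY")
        return values

    def run(self, query: str) -> List[Document]:
        # Placeholder result; a real wrapper would call its backing API here.
        return [Document(page_content=f"result for {query!r}")]
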
diff --git a/libs/langchain/langchain/utilities/dalle_image_generator.py b/libs/langchain/langchain/utilities/dalle_image_generator.py index e805aabe505..cd8c98bd33e 100644 --- a/libs/langchain/langchain/utilities/dalle_image_generator.py +++ b/libs/langchain/langchain/utilities/dalle_image_generator.py @@ -1,7 +1,8 @@ """Utility that calls OpenAI's Dall-E Image Generator.""" from typing import Any, Dict, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/dataforseo_api_search.py b/libs/langchain/langchain/utilities/dataforseo_api_search.py index e774ff7c935..cb29e5c3da1 100644 --- a/libs/langchain/langchain/utilities/dataforseo_api_search.py +++ b/libs/langchain/langchain/utilities/dataforseo_api_search.py @@ -4,8 +4,8 @@ from urllib.parse import quote import aiohttp import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/duckduckgo_search.py b/libs/langchain/langchain/utilities/duckduckgo_search.py index b293ea6d7d6..67e89875cec 100644 --- a/libs/langchain/langchain/utilities/duckduckgo_search.py +++ b/libs/langchain/langchain/utilities/duckduckgo_search.py @@ -5,7 +5,7 @@ https://pypi.org/project/duckduckgo-search/ """ from typing import Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator class DuckDuckGoSearchAPIWrapper(BaseModel): diff --git a/libs/langchain/langchain/utilities/github.py b/libs/langchain/langchain/utilities/github.py index 7ef1c2e4b78..234ec87b1e9 100644 --- a/libs/langchain/langchain/utilities/github.py +++ b/libs/langchain/langchain/utilities/github.py @@ -4,7 +4,8 @@ from __future__ import annotations import json from typing import TYPE_CHECKING, Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + from langchain.utils import get_from_dict_or_env if TYPE_CHECKING: diff --git a/libs/langchain/langchain/utilities/gitlab.py b/libs/langchain/langchain/utilities/gitlab.py index 0ad8db3c995..78765637c71 100644 --- a/libs/langchain/langchain/utilities/gitlab.py +++ b/libs/langchain/langchain/utilities/gitlab.py @@ -4,7 +4,8 @@ from __future__ import annotations import json from typing import TYPE_CHECKING, Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + from langchain.utils import get_from_dict_or_env if TYPE_CHECKING: diff --git a/libs/langchain/langchain/utilities/golden_query.py b/libs/langchain/langchain/utilities/golden_query.py index e94b49c1c0b..965a218ec79 100644 --- a/libs/langchain/langchain/utilities/golden_query.py +++ b/libs/langchain/langchain/utilities/golden_query.py @@ -3,8 +3,8 @@ import json from typing import Dict, Optional import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.pydantic_v1 import BaseModel, Extra, root_validator from langchain.utils import get_from_dict_or_env GOLDEN_BASE_URL = "https://golden.com" diff --git 
a/libs/langchain/langchain/utilities/google_places_api.py b/libs/langchain/langchain/utilities/google_places_api.py index abd45394741..290b48853dd 100644 --- a/libs/langchain/langchain/utilities/google_places_api.py +++ b/libs/langchain/langchain/utilities/google_places_api.py @@ -4,7 +4,8 @@ import logging from typing import Any, Dict, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/google_scholar.py b/libs/langchain/langchain/utilities/google_scholar.py index 777394f1925..9c9232b7cf9 100644 --- a/libs/langchain/langchain/utilities/google_scholar.py +++ b/libs/langchain/langchain/utilities/google_scholar.py @@ -1,7 +1,8 @@ """Util that calls Google Scholar Search.""" from typing import Dict, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/google_search.py b/libs/langchain/langchain/utilities/google_search.py index e2c115b9011..d6db48d8a7d 100644 --- a/libs/langchain/langchain/utilities/google_search.py +++ b/libs/langchain/langchain/utilities/google_search.py @@ -1,7 +1,8 @@ """Util that calls Google Search.""" from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/google_serper.py b/libs/langchain/langchain/utilities/google_serper.py index 15a76148818..9efa3bf28e1 100644 --- a/libs/langchain/langchain/utilities/google_serper.py +++ b/libs/langchain/langchain/utilities/google_serper.py @@ -3,9 +3,9 @@ from typing import Any, Dict, List, Optional import aiohttp import requests +from langchain_core.pydantic_v1 import BaseModel, root_validator from typing_extensions import Literal -from langchain.pydantic_v1 import BaseModel, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/graphql.py b/libs/langchain/langchain/utilities/graphql.py index 5bc548e881b..87be94d09c3 100644 --- a/libs/langchain/langchain/utilities/graphql.py +++ b/libs/langchain/langchain/utilities/graphql.py @@ -1,7 +1,7 @@ import json from typing import Any, Callable, Dict, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator class GraphQLAPIWrapper(BaseModel): diff --git a/libs/langchain/langchain/utilities/jira.py b/libs/langchain/langchain/utilities/jira.py index ccfd4dc7b9f..c08c9a9a5f1 100644 --- a/libs/langchain/langchain/utilities/jira.py +++ b/libs/langchain/langchain/utilities/jira.py @@ -1,7 +1,8 @@ """Util that calls Jira.""" from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/loading.py b/libs/langchain/langchain/utilities/loading.py index ea46982495a..a2337b6df41 100644 --- a/libs/langchain/langchain/utilities/loading.py +++ b/libs/langchain/langchain/utilities/loading.py @@ -1,4 +1,4 @@ -from 
langchain.utils.loading import try_load_from_hub +from langchain_core.utils.loading import try_load_from_hub # For backwards compatibility __all__ = ["try_load_from_hub"] diff --git a/libs/langchain/langchain/utilities/metaphor_search.py b/libs/langchain/langchain/utilities/metaphor_search.py index 6dea326934a..7b2aaa57731 100644 --- a/libs/langchain/langchain/utilities/metaphor_search.py +++ b/libs/langchain/langchain/utilities/metaphor_search.py @@ -7,8 +7,8 @@ from typing import Dict, List, Optional import aiohttp import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.pydantic_v1 import BaseModel, Extra, root_validator from langchain.utils import get_from_dict_or_env METAPHOR_API_URL = "https://api.metaphor.systems" diff --git a/libs/langchain/langchain/utilities/openapi.py b/libs/langchain/langchain/utilities/openapi.py index bdcf1d85fb2..71263c369d5 100644 --- a/libs/langchain/langchain/utilities/openapi.py +++ b/libs/langchain/langchain/utilities/openapi.py @@ -11,8 +11,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union import requests import yaml - -from langchain.pydantic_v1 import ValidationError +from langchain_core.pydantic_v1 import ValidationError logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/utilities/openweathermap.py b/libs/langchain/langchain/utilities/openweathermap.py index 8b64a7dff64..a8aa5e7406d 100644 --- a/libs/langchain/langchain/utilities/openweathermap.py +++ b/libs/langchain/langchain/utilities/openweathermap.py @@ -1,7 +1,8 @@ """Util that calls OpenWeatherMap using PyOWM.""" from typing import Any, Dict, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/powerbi.py b/libs/langchain/langchain/utilities/powerbi.py index f201bc49850..88219936260 100644 --- a/libs/langchain/langchain/utilities/powerbi.py +++ b/libs/langchain/langchain/utilities/powerbi.py @@ -9,10 +9,9 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union import aiohttp import requests from aiohttp import ServerTimeoutError +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator from requests.exceptions import Timeout -from langchain.pydantic_v1 import BaseModel, Field, root_validator, validator - logger = logging.getLogger(__name__) BASE_URL = os.getenv("POWERBI_BASE_URL", "https://api.powerbi.com/v1.0/myorg") diff --git a/libs/langchain/langchain/utilities/pubmed.py b/libs/langchain/langchain/utilities/pubmed.py index 06e241d3917..6a2e78abb03 100644 --- a/libs/langchain/langchain/utilities/pubmed.py +++ b/libs/langchain/langchain/utilities/pubmed.py @@ -6,8 +6,8 @@ import urllib.parse import urllib.request from typing import Any, Dict, Iterator, List -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema import Document +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema import Document logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/utilities/python.py b/libs/langchain/langchain/utilities/python.py index d2f5d2fb72a..70c3119e5f6 100644 --- a/libs/langchain/langchain/utilities/python.py +++ b/libs/langchain/langchain/utilities/python.py @@ -5,7 +5,7 @@ import sys from io import StringIO from typing import Dict, Optional -from langchain.pydantic_v1 import 
BaseModel, Field +from langchain_core.pydantic_v1 import BaseModel, Field logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/utilities/requests.py b/libs/langchain/langchain/utilities/requests.py index 2652eb0915f..651616ff531 100644 --- a/libs/langchain/langchain/utilities/requests.py +++ b/libs/langchain/langchain/utilities/requests.py @@ -4,8 +4,7 @@ from typing import Any, AsyncGenerator, Dict, Optional import aiohttp import requests - -from langchain.pydantic_v1 import BaseModel, Extra +from langchain_core.pydantic_v1 import BaseModel, Extra class Requests(BaseModel): diff --git a/libs/langchain/langchain/utilities/scenexplain.py b/libs/langchain/langchain/utilities/scenexplain.py index c43348ec8b8..7b181e4d5c3 100644 --- a/libs/langchain/langchain/utilities/scenexplain.py +++ b/libs/langchain/langchain/utilities/scenexplain.py @@ -8,8 +8,8 @@ You can obtain a key by following the steps below. from typing import Dict import requests +from langchain_core.pydantic_v1 import BaseModel, BaseSettings, Field, root_validator -from langchain.pydantic_v1 import BaseModel, BaseSettings, Field, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/searchapi.py b/libs/langchain/langchain/utilities/searchapi.py index b5cbf646b9e..e89552096e4 100644 --- a/libs/langchain/langchain/utilities/searchapi.py +++ b/libs/langchain/langchain/utilities/searchapi.py @@ -2,8 +2,8 @@ from typing import Any, Dict, Optional import aiohttp import requests +from langchain_core.pydantic_v1 import BaseModel, root_validator -from langchain.pydantic_v1 import BaseModel, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/searx_search.py b/libs/langchain/langchain/utilities/searx_search.py index 5c0b6fddc67..6d43f5bf7d0 100644 --- a/libs/langchain/langchain/utilities/searx_search.py +++ b/libs/langchain/langchain/utilities/searx_search.py @@ -132,8 +132,7 @@ from typing import Any, Dict, List, Optional import aiohttp import requests - -from langchain.pydantic_v1 import ( +from langchain_core.pydantic_v1 import ( BaseModel, Extra, Field, @@ -141,6 +140,7 @@ from langchain.pydantic_v1 import ( root_validator, validator, ) + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/serpapi.py b/libs/langchain/langchain/utilities/serpapi.py index 2fadaf4a90a..8a74701f49d 100644 --- a/libs/langchain/langchain/utilities/serpapi.py +++ b/libs/langchain/langchain/utilities/serpapi.py @@ -7,8 +7,8 @@ import sys from typing import Any, Dict, Optional, Tuple import aiohttp +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/tavily_search.py b/libs/langchain/langchain/utilities/tavily_search.py index 57267f22848..d16135db65e 100644 --- a/libs/langchain/langchain/utilities/tavily_search.py +++ b/libs/langchain/langchain/utilities/tavily_search.py @@ -7,8 +7,8 @@ from typing import Dict, List, Optional import aiohttp import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.pydantic_v1 import BaseModel, Extra, root_validator from langchain.utils import get_from_dict_or_env TAVILY_API_URL = "https://api.tavily.com" diff --git a/libs/langchain/langchain/utilities/tensorflow_datasets.py 
b/libs/langchain/langchain/utilities/tensorflow_datasets.py index 680fe51d871..cc5b2fdf894 100644 --- a/libs/langchain/langchain/utilities/tensorflow_datasets.py +++ b/libs/langchain/langchain/utilities/tensorflow_datasets.py @@ -1,8 +1,8 @@ import logging from typing import Any, Callable, Dict, Iterator, List, Optional -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema import Document +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema import Document logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/utilities/twilio.py b/libs/langchain/langchain/utilities/twilio.py index 07f937357c4..0798195dfff 100644 --- a/libs/langchain/langchain/utilities/twilio.py +++ b/libs/langchain/langchain/utilities/twilio.py @@ -1,7 +1,8 @@ """Util that calls Twilio.""" from typing import Any, Dict, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/wikipedia.py b/libs/langchain/langchain/utilities/wikipedia.py index dd38408e627..6df84925bda 100644 --- a/libs/langchain/langchain/utilities/wikipedia.py +++ b/libs/langchain/langchain/utilities/wikipedia.py @@ -2,8 +2,8 @@ import logging from typing import Any, Dict, List, Optional -from langchain.pydantic_v1 import BaseModel, root_validator -from langchain.schema import Document +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.schema import Document logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/utilities/wolfram_alpha.py b/libs/langchain/langchain/utilities/wolfram_alpha.py index 1af50eaf4f3..e599511f38f 100644 --- a/libs/langchain/langchain/utilities/wolfram_alpha.py +++ b/libs/langchain/langchain/utilities/wolfram_alpha.py @@ -1,7 +1,8 @@ """Util that calls WolframAlpha.""" from typing import Any, Dict, Optional -from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utilities/zapier.py b/libs/langchain/langchain/utilities/zapier.py index 3c435eaafbb..d5edb760594 100644 --- a/libs/langchain/langchain/utilities/zapier.py +++ b/libs/langchain/langchain/utilities/zapier.py @@ -16,9 +16,9 @@ from typing import Any, Dict, List, Optional import aiohttp import requests +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from requests import Request, Session -from langchain.pydantic_v1 import BaseModel, Extra, root_validator from langchain.utils import get_from_dict_or_env diff --git a/libs/langchain/langchain/utils/__init__.py b/libs/langchain/langchain/utils/__init__.py index 7a3a7b759de..4a571ed3d35 100644 --- a/libs/langchain/langchain/utils/__init__.py +++ b/libs/langchain/langchain/utils/__init__.py @@ -4,17 +4,14 @@ These functions do not depend on any other LangChain module. 
""" -from langchain.utils.env import get_from_dict_or_env, get_from_env -from langchain.utils.formatting import StrictFormatter, formatter -from langchain.utils.input import ( +from langchain_core.utils.formatting import StrictFormatter, formatter +from langchain_core.utils.input import ( get_bolded_text, get_color_mapping, get_colored_text, print_text, ) -from langchain.utils.math import cosine_similarity, cosine_similarity_top_k -from langchain.utils.strings import comma_list, stringify_dict, stringify_value -from langchain.utils.utils import ( +from langchain_core.utils.utils import ( check_package_version, convert_to_secret_str, get_pydantic_field_names, @@ -24,6 +21,10 @@ from langchain.utils.utils import ( xor_args, ) +from langchain.utils.env import get_from_dict_or_env, get_from_env +from langchain.utils.math import cosine_similarity, cosine_similarity_top_k +from langchain.utils.strings import comma_list, stringify_dict, stringify_value + __all__ = [ "StrictFormatter", "check_package_version", diff --git a/libs/langchain/langchain/utils/aiter.py b/libs/langchain/langchain/utils/aiter.py index ca44dee3958..cab956b5d07 100644 --- a/libs/langchain/langchain/utils/aiter.py +++ b/libs/langchain/langchain/utils/aiter.py @@ -1,209 +1,3 @@ -""" -Adapted from -https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py -MIT License -""" +from langchain_core.utils.aiter import NoLock, Tee, py_anext -from collections import deque -from typing import ( - Any, - AsyncContextManager, - AsyncGenerator, - AsyncIterator, - Awaitable, - Callable, - Deque, - Generic, - Iterator, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, - overload, -) - -T = TypeVar("T") - -_no_default = object() - - -# https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54 -# before 3.10, the builtin anext() was not available -def py_anext( - iterator: AsyncIterator[T], default: Union[T, Any] = _no_default -) -> Awaitable[Union[T, None, Any]]: - """Pure-Python implementation of anext() for testing purposes. - - Closely matches the builtin anext() C implementation. - Can be used to compare the built-in implementation of the inner - coroutines machinery to C-implementation of __anext__() and send() - or throw() on the returned generator. - """ - - try: - __anext__ = cast( - Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__ - ) - except AttributeError: - raise TypeError(f"{iterator!r} is not an async iterator") - - if default is _no_default: - return __anext__(iterator) - - async def anext_impl() -> Union[T, Any]: - try: - # The C code is way more low-level than this, as it implements - # all methods of the iterator protocol. In this implementation - # we're relying on higher-level coroutine concepts, but that's - # exactly what we want -- crosstest pure-Python high-level - # implementation and low-level C anext() iterators. 
- return await __anext__(iterator) - except StopAsyncIteration: - return default - - return anext_impl() - - -class NoLock: - """Dummy lock that provides the proper interface but no protection""" - - async def __aenter__(self) -> None: - pass - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - return False - - -async def tee_peer( - iterator: AsyncIterator[T], - # the buffer specific to this peer - buffer: Deque[T], - # the buffers of all peers, including our own - peers: List[Deque[T]], - lock: AsyncContextManager[Any], -) -> AsyncGenerator[T, None]: - """An individual iterator of a :py:func:`~.tee`""" - try: - while True: - if not buffer: - async with lock: - # Another peer produced an item while we were waiting for the lock. - # Proceed with the next loop iteration to yield the item. - if buffer: - continue - try: - item = await iterator.__anext__() - except StopAsyncIteration: - break - else: - # Append to all buffers, including our own. We'll fetch our - # item from the buffer again, instead of yielding it directly. - # This ensures the proper item ordering if any of our peers - # are fetching items concurrently. They may have buffered their - # item already. - for peer_buffer in peers: - peer_buffer.append(item) - yield buffer.popleft() - finally: - async with lock: - # this peer is done – remove its buffer - for idx, peer_buffer in enumerate(peers): # pragma: no branch - if peer_buffer is buffer: - peers.pop(idx) - break - # if we are the last peer, try and close the iterator - if not peers and hasattr(iterator, "aclose"): - await iterator.aclose() - - -class Tee(Generic[T]): - """ - Create ``n`` separate asynchronous iterators over ``iterable`` - - This splits a single ``iterable`` into multiple iterators, each providing - the same items in the same order. - All child iterators may advance separately but share the same items - from ``iterable`` -- when the most advanced iterator retrieves an item, - it is buffered until the least advanced iterator has yielded it as well. - A ``tee`` works lazily and can handle an infinite ``iterable``, provided - that all iterators advance. - - .. code-block:: python3 - - async def derivative(sensor_data): - previous, current = a.tee(sensor_data, n=2) - await a.anext(previous) # advance one iterator - return a.map(operator.sub, previous, current) - - Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead - of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked - to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method - immediately closes all children, and it can be used in an ``async with`` context - for the same effect. - - If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not* - provide these items. Also, ``tee`` must internally buffer each item until the - last iterator has yielded it; if the most and least advanced iterator differ - by most data, using a :py:class:`list` is more efficient (but not lazy). - - If the underlying iterable is concurrency safe (``anext`` may be awaited - concurrently) the resulting iterators are concurrency safe as well. Otherwise, - the iterators are safe if there is only ever one single "most advanced" iterator. - To enforce sequential use of ``anext``, provide a ``lock`` - - e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application - - and access is automatically synchronised. 
- """ - - def __init__( - self, - iterable: AsyncIterator[T], - n: int = 2, - *, - lock: Optional[AsyncContextManager[Any]] = None, - ): - self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist - self._buffers: List[Deque[T]] = [deque() for _ in range(n)] - self._children = tuple( - tee_peer( - iterator=self._iterator, - buffer=buffer, - peers=self._buffers, - lock=lock if lock is not None else NoLock(), - ) - for buffer in self._buffers - ) - - def __len__(self) -> int: - return len(self._children) - - @overload - def __getitem__(self, item: int) -> AsyncIterator[T]: - ... - - @overload - def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]: - ... - - def __getitem__( - self, item: Union[int, slice] - ) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]: - return self._children[item] - - def __iter__(self) -> Iterator[AsyncIterator[T]]: - yield from self._children - - async def __aenter__(self) -> "Tee[T]": - return self - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - await self.aclose() - return False - - async def aclose(self) -> None: - for child in self._children: - await child.aclose() - - -atee = Tee +__all__ = ["py_anext", "NoLock", "Tee"] diff --git a/libs/langchain/langchain/utils/formatting.py b/libs/langchain/langchain/utils/formatting.py index 3b3b597b083..212bff83613 100644 --- a/libs/langchain/langchain/utils/formatting.py +++ b/libs/langchain/langchain/utils/formatting.py @@ -1,38 +1,3 @@ -"""Utilities for formatting strings.""" -from string import Formatter -from typing import Any, List, Mapping, Sequence, Union +from langchain_core.utils.formatting import StrictFormatter - -class StrictFormatter(Formatter): - """A subclass of formatter that checks for extra keys.""" - - def check_unused_args( - self, - used_args: Sequence[Union[int, str]], - args: Sequence, - kwargs: Mapping[str, Any], - ) -> None: - """Check to see if extra parameters are passed.""" - extra = set(kwargs).difference(used_args) - if extra: - raise KeyError(extra) - - def vformat( - self, format_string: str, args: Sequence, kwargs: Mapping[str, Any] - ) -> str: - """Check that no arguments are provided.""" - if len(args) > 0: - raise ValueError( - "No arguments should be provided, " - "everything should be passed as keyword arguments." 
- ) - return super().vformat(format_string, args, kwargs) - - def validate_input_variables( - self, format_string: str, input_variables: List[str] - ) -> None: - dummy_inputs = {input_variable: "foo" for input_variable in input_variables} - super().format(format_string, **dummy_inputs) - - -formatter = StrictFormatter() +__all__ = ["StrictFormatter"] diff --git a/libs/langchain/langchain/utils/input.py b/libs/langchain/langchain/utils/input.py index 8d5ae6cc24f..563cc506076 100644 --- a/libs/langchain/langchain/utils/input.py +++ b/libs/langchain/langchain/utils/input.py @@ -1,42 +1,8 @@ -"""Handle chained inputs.""" -from typing import Dict, List, Optional, TextIO +from langchain_core.utils.input import ( + get_bolded_text, + get_color_mapping, + get_colored_text, + print_text, +) -_TEXT_COLOR_MAPPING = { - "blue": "36;1", - "yellow": "33;1", - "pink": "38;5;200", - "green": "32;1", - "red": "31;1", -} - - -def get_color_mapping( - items: List[str], excluded_colors: Optional[List] = None -) -> Dict[str, str]: - """Get mapping for items to a support color.""" - colors = list(_TEXT_COLOR_MAPPING.keys()) - if excluded_colors is not None: - colors = [c for c in colors if c not in excluded_colors] - color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)} - return color_mapping - - -def get_colored_text(text: str, color: str) -> str: - """Get colored text.""" - color_str = _TEXT_COLOR_MAPPING[color] - return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" - - -def get_bolded_text(text: str) -> str: - """Get bolded text.""" - return f"\033[1m{text}\033[0m" - - -def print_text( - text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None -) -> None: - """Print text with highlighting and no end characters.""" - text_to_print = get_colored_text(text, color) if color else text - print(text_to_print, end=end, file=file) - if file: - file.flush() # ensure all printed content are written to file +__all__ = ["get_color_mapping", "get_colored_text", "get_bolded_text", "print_text"] diff --git a/libs/langchain/langchain/utils/iter.py b/libs/langchain/langchain/utils/iter.py index 60834163c3f..a4059721241 100644 --- a/libs/langchain/langchain/utils/iter.py +++ b/libs/langchain/langchain/utils/iter.py @@ -1,175 +1,3 @@ -from collections import deque -from itertools import islice -from typing import ( - Any, - ContextManager, - Deque, - Generator, - Generic, - Iterable, - Iterator, - List, - Optional, - Tuple, - TypeVar, - Union, - overload, -) +from langchain_core.utils.iter import NoLock, Tee, batch_iterate, tee_peer -from typing_extensions import Literal - -T = TypeVar("T") - - -class NoLock: - """Dummy lock that provides the proper interface but no protection""" - - def __enter__(self) -> None: - pass - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: - return False - - -def tee_peer( - iterator: Iterator[T], - # the buffer specific to this peer - buffer: Deque[T], - # the buffers of all peers, including our own - peers: List[Deque[T]], - lock: ContextManager[Any], -) -> Generator[T, None, None]: - """An individual iterator of a :py:func:`~.tee`""" - try: - while True: - if not buffer: - with lock: - # Another peer produced an item while we were waiting for the lock. - # Proceed with the next loop iteration to yield the item. - if buffer: - continue - try: - item = next(iterator) - except StopIteration: - break - else: - # Append to all buffers, including our own. 
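The formatting and input shims above follow the same pattern: the implementation moves to langchain_core and the old module keeps only an import plus __all__. A brief sketch of the StrictFormatter behaviour described in the removed code, using the langchain_core.utils.formatting path introduced above (assumes langchain_core is installed):

from langchain_core.utils.formatting import StrictFormatter

formatter = StrictFormatter()
print(formatter.format("{a} and {b}", a=1, b=2))  # -> "1 and 2"

# Extra keyword arguments are rejected rather than silently ignored:
# formatter.format("{a}", a=1, extra=2)   # raises KeyError({'extra'})

# Positional arguments are rejected outright:
# formatter.format("{0}", 1)              # raises ValueError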
We'll fetch our - # item from the buffer again, instead of yielding it directly. - # This ensures the proper item ordering if any of our peers - # are fetching items concurrently. They may have buffered their - # item already. - for peer_buffer in peers: - peer_buffer.append(item) - yield buffer.popleft() - finally: - with lock: - # this peer is done – remove its buffer - for idx, peer_buffer in enumerate(peers): # pragma: no branch - if peer_buffer is buffer: - peers.pop(idx) - break - # if we are the last peer, try and close the iterator - if not peers and hasattr(iterator, "close"): - iterator.close() - - -class Tee(Generic[T]): - """ - Create ``n`` separate asynchronous iterators over ``iterable`` - - This splits a single ``iterable`` into multiple iterators, each providing - the same items in the same order. - All child iterators may advance separately but share the same items - from ``iterable`` -- when the most advanced iterator retrieves an item, - it is buffered until the least advanced iterator has yielded it as well. - A ``tee`` works lazily and can handle an infinite ``iterable``, provided - that all iterators advance. - - .. code-block:: python3 - - async def derivative(sensor_data): - previous, current = a.tee(sensor_data, n=2) - await a.anext(previous) # advance one iterator - return a.map(operator.sub, previous, current) - - Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead - of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked - to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method - immediately closes all children, and it can be used in an ``async with`` context - for the same effect. - - If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not* - provide these items. Also, ``tee`` must internally buffer each item until the - last iterator has yielded it; if the most and least advanced iterator differ - by most data, using a :py:class:`list` is more efficient (but not lazy). - - If the underlying iterable is concurrency safe (``anext`` may be awaited - concurrently) the resulting iterators are concurrency safe as well. Otherwise, - the iterators are safe if there is only ever one single "most advanced" iterator. - To enforce sequential use of ``anext``, provide a ``lock`` - - e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application - - and access is automatically synchronised. - """ - - def __init__( - self, - iterable: Iterator[T], - n: int = 2, - *, - lock: Optional[ContextManager[Any]] = None, - ): - self._iterator = iter(iterable) - self._buffers: List[Deque[T]] = [deque() for _ in range(n)] - self._children = tuple( - tee_peer( - iterator=self._iterator, - buffer=buffer, - peers=self._buffers, - lock=lock if lock is not None else NoLock(), - ) - for buffer in self._buffers - ) - - def __len__(self) -> int: - return len(self._children) - - @overload - def __getitem__(self, item: int) -> Iterator[T]: - ... - - @overload - def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]: - ... 
- - def __getitem__( - self, item: Union[int, slice] - ) -> Union[Iterator[T], Tuple[Iterator[T], ...]]: - return self._children[item] - - def __iter__(self) -> Iterator[Iterator[T]]: - yield from self._children - - def __enter__(self) -> "Tee[T]": - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: - self.close() - return False - - def close(self) -> None: - for child in self._children: - child.close() - - -# Why this is needed https://stackoverflow.com/a/44638570 -safetee = Tee - - -def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]: - """Utility batching function.""" - it = iter(iterable) - while True: - chunk = list(islice(it, size)) - if not chunk: - return - yield chunk +__all__ = ["NoLock", "tee_peer", "Tee", "batch_iterate"] diff --git a/libs/langchain/langchain/utils/loading.py b/libs/langchain/langchain/utils/loading.py index 60f3e3cf7d4..b048d383eeb 100644 --- a/libs/langchain/langchain/utils/loading.py +++ b/libs/langchain/langchain/utils/loading.py @@ -1,54 +1,3 @@ -"""Utilities for loading configurations from langchain-hub.""" +from langchain_core.utils.loading import try_load_from_hub -import os -import re -import tempfile -from pathlib import Path, PurePosixPath -from typing import Any, Callable, Optional, Set, TypeVar, Union -from urllib.parse import urljoin - -import requests - -DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master") -URL_BASE = os.environ.get( - "LANGCHAIN_HUB_URL_BASE", - "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/", -) -HUB_PATH_RE = re.compile(r"lc(?P@[^:]+)?://(?P.*)") - -T = TypeVar("T") - - -def try_load_from_hub( - path: Union[str, Path], - loader: Callable[[str], T], - valid_prefix: str, - valid_suffixes: Set[str], - **kwargs: Any, -) -> Optional[T]: - """Load configuration from hub. Returns None if path is not a hub path.""" - if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)): - return None - ref, remote_path_str = match.groups() - ref = ref[1:] if ref else DEFAULT_REF - remote_path = Path(remote_path_str) - if remote_path.parts[0] != valid_prefix: - return None - if remote_path.suffix[1:] not in valid_suffixes: - raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.") - - # Using Path with URLs is not recommended, because on Windows - # the backslash is used as the path separator, which can cause issues - # when working with URLs that use forward slashes as the path separator. - # Instead, use PurePosixPath to ensure that forward slashes are used as the - # path separator, regardless of the operating system. 
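Because langchain.utils.iter now simply re-exports the langchain_core implementations, the old and new import paths resolve to the same objects. A small sketch, assuming both packages are installed after the split:

from langchain.utils.iter import batch_iterate as legacy_batch_iterate
from langchain_core.utils.iter import batch_iterate

# The shim re-exports the moved implementation, so both names bind the same function.
assert legacy_batch_iterate is batch_iterate

# batch_iterate(size, iterable) yields lists of at most `size` items.
assert list(batch_iterate(2, range(5))) == [[0, 1], [2, 3], [4]]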
- full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__()) - - r = requests.get(full_url, timeout=5) - if r.status_code != 200: - raise ValueError(f"Could not find file at {full_url}") - with tempfile.TemporaryDirectory() as tmpdirname: - file = Path(tmpdirname) / remote_path.name - with open(file, "wb") as f: - f.write(r.content) - return loader(str(file), **kwargs) +__all__ = ["try_load_from_hub"] diff --git a/libs/langchain/langchain/utils/openai_functions.py b/libs/langchain/langchain/utils/openai_functions.py index 02a57a1ce92..6380d197649 100644 --- a/libs/langchain/langchain/utils/openai_functions.py +++ b/libs/langchain/langchain/utils/openai_functions.py @@ -1,6 +1,7 @@ from typing import Literal, Optional, Type, TypedDict -from langchain.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel + from langchain.utils.json_schema import dereference_refs diff --git a/libs/langchain/langchain/utils/pydantic.py b/libs/langchain/langchain/utils/pydantic.py index 80ddb81fcb9..a5000f231e8 100644 --- a/libs/langchain/langchain/utils/pydantic.py +++ b/libs/langchain/langchain/utils/pydantic.py @@ -1,14 +1,3 @@ -"""Utilities for tests.""" +from langchain_core.utils.pydantic import get_pydantic_major_version - -def get_pydantic_major_version() -> int: - """Get the major version of Pydantic.""" - try: - import pydantic - - return int(pydantic.__version__.split(".")[0]) - except ImportError: - return 0 - - -PYDANTIC_MAJOR_VERSION = get_pydantic_major_version() +__all__ = ["get_pydantic_major_version"] diff --git a/libs/langchain/langchain/utils/utils.py b/libs/langchain/langchain/utils/utils.py index ece5f6aa1be..57629433164 100644 --- a/libs/langchain/langchain/utils/utils.py +++ b/libs/langchain/langchain/utils/utils.py @@ -1,180 +1,21 @@ -"""Generic utility functions.""" -import contextlib -import datetime -import functools -import importlib -import warnings -from importlib.metadata import version -from typing import Any, Callable, Dict, Optional, Set, Tuple, Union +from langchain_core.utils.utils import ( + build_extra_kwargs, + check_package_version, + convert_to_secret_str, + get_pydantic_field_names, + guard_import, + mock_now, + raise_for_status_with_text, + xor_args, +) -from packaging.version import parse -from requests import HTTPError, Response - -from langchain.pydantic_v1 import SecretStr - - -def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: - """Validate specified keyword args are mutually exclusive.""" - - def decorator(func: Callable) -> Callable: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - """Validate exactly one arg in each group is not None.""" - counts = [ - sum(1 for arg in arg_group if kwargs.get(arg) is not None) - for arg_group in arg_groups - ] - invalid_groups = [i for i, count in enumerate(counts) if count != 1] - if invalid_groups: - invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups] - raise ValueError( - "Exactly one argument in each of the following" - " groups must be defined:" - f" {', '.join(invalid_group_names)}" - ) - return func(*args, **kwargs) - - return wrapper - - return decorator - - -def raise_for_status_with_text(response: Response) -> None: - """Raise an error with the response text.""" - try: - response.raise_for_status() - except HTTPError as e: - raise ValueError(response.text) from e - - -@contextlib.contextmanager -def mock_now(dt_value): # type: ignore - """Context manager for mocking out datetime.now() in unit tests. 
- - Example: - with mock_now(datetime.datetime(2011, 2, 3, 10, 11)): - assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11) - """ - - class MockDateTime(datetime.datetime): - """Mock datetime.datetime.now() with a fixed datetime.""" - - @classmethod - def now(cls): # type: ignore - # Create a copy of dt_value. - return datetime.datetime( - dt_value.year, - dt_value.month, - dt_value.day, - dt_value.hour, - dt_value.minute, - dt_value.second, - dt_value.microsecond, - dt_value.tzinfo, - ) - - real_datetime = datetime.datetime - datetime.datetime = MockDateTime - try: - yield datetime.datetime - finally: - datetime.datetime = real_datetime - - -def guard_import( - module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None -) -> Any: - """Dynamically imports a module and raises a helpful exception if the module is not - installed.""" - try: - module = importlib.import_module(module_name, package) - except ImportError: - raise ImportError( - f"Could not import {module_name} python package. " - f"Please install it with `pip install {pip_name or module_name}`." - ) - return module - - -def check_package_version( - package: str, - lt_version: Optional[str] = None, - lte_version: Optional[str] = None, - gt_version: Optional[str] = None, - gte_version: Optional[str] = None, -) -> None: - """Check the version of a package.""" - imported_version = parse(version(package)) - if lt_version is not None and imported_version >= parse(lt_version): - raise ValueError( - f"Expected {package} version to be < {lt_version}. Received " - f"{imported_version}." - ) - if lte_version is not None and imported_version > parse(lte_version): - raise ValueError( - f"Expected {package} version to be <= {lte_version}. Received " - f"{imported_version}." - ) - if gt_version is not None and imported_version <= parse(gt_version): - raise ValueError( - f"Expected {package} version to be > {gt_version}. Received " - f"{imported_version}." - ) - if gte_version is not None and imported_version < parse(gte_version): - raise ValueError( - f"Expected {package} version to be >= {gte_version}. Received " - f"{imported_version}." - ) - - -def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]: - """Get field names, including aliases, for a pydantic class. - - Args: - pydantic_cls: Pydantic class.""" - all_required_field_names = set() - for field in pydantic_cls.__fields__.values(): - all_required_field_names.add(field.name) - if field.has_alias: - all_required_field_names.add(field.alias) - return all_required_field_names - - -def build_extra_kwargs( - extra_kwargs: Dict[str, Any], - values: Dict[str, Any], - all_required_field_names: Set[str], -) -> Dict[str, Any]: - """Build extra kwargs from values and extra_kwargs. - - Args: - extra_kwargs: Extra kwargs passed in by user. - values: Values passed in by user. - all_required_field_names: All required field names for the pydantic class. - """ - for field_name in list(values): - if field_name in extra_kwargs: - raise ValueError(f"Found {field_name} supplied twice.") - if field_name not in all_required_field_names: - warnings.warn( - f"""WARNING! {field_name} is not default parameter. - {field_name} was transferred to model_kwargs. 
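For context on the utilities being re-homed here, a short sketch of xor_args, imported via the langchain_core.utils.utils path introduced above; the load() helper is hypothetical and exists only to illustrate the decorator:

from langchain_core.utils.utils import xor_args


@xor_args(("text", "file_path"))
def load(text=None, file_path=None):
    """Hypothetical helper: accepts exactly one of `text` or `file_path`."""
    if text is not None:
        return text
    with open(file_path) as f:
        return f.read()


load(text="hello")               # fine: exactly one argument in the group is set
# load()                         # raises ValueError: exactly one must be defined
# load(text="a", file_path="b")  # also raises ValueError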
- Please confirm that {field_name} is what you intended.""" - ) - extra_kwargs[field_name] = values.pop(field_name) - - invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys()) - if invalid_model_kwargs: - raise ValueError( - f"Parameters {invalid_model_kwargs} should be specified explicitly. " - f"Instead they were passed in as part of `model_kwargs` parameter." - ) - - return extra_kwargs - - -def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr: - """Convert a string to a SecretStr if needed.""" - if isinstance(value, SecretStr): - return value - return SecretStr(value) +__all__ = [ + "xor_args", + "raise_for_status_with_text", + "mock_now", + "guard_import", + "check_package_version", + "get_pydantic_field_names", + "build_extra_kwargs", + "convert_to_secret_str", +] diff --git a/libs/langchain/langchain/vectorstores/__init__.py b/libs/langchain/langchain/vectorstores/__init__.py index 4a1c6dd9696..d91d125d0b1 100644 --- a/libs/langchain/langchain/vectorstores/__init__.py +++ b/libs/langchain/langchain/vectorstores/__init__.py @@ -21,7 +21,7 @@ and retrieve the data that are 'most similar' to the embedded query. from typing import Any -from langchain.schema.vectorstore import VectorStore +from langchain_core.schema.vectorstore import VectorStore def _import_alibaba_cloud_open_search() -> Any: diff --git a/libs/langchain/langchain/vectorstores/alibabacloud_opensearch.py b/libs/langchain/langchain/vectorstores/alibabacloud_opensearch.py index 6740c1d4003..f1ed358684b 100644 --- a/libs/langchain/langchain/vectorstores/alibabacloud_opensearch.py +++ b/libs/langchain/langchain/vectorstores/alibabacloud_opensearch.py @@ -4,9 +4,9 @@ import numbers from hashlib import sha1 from typing import Any, Dict, Iterable, List, Optional, Tuple -from langchain.schema import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore +from langchain_core.schema import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore logger = logging.getLogger() diff --git a/libs/langchain/langchain/vectorstores/analyticdb.py b/libs/langchain/langchain/vectorstores/analyticdb.py index 1792ff6be65..c27ee0ac98e 100644 --- a/libs/langchain/langchain/vectorstores/analyticdb.py +++ b/libs/langchain/langchain/vectorstores/analyticdb.py @@ -12,9 +12,10 @@ try: except ImportError: from sqlalchemy.ext.declarative import declarative_base +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_dict_or_env _LANGCHAIN_DEFAULT_EMBEDDING_DIM = 1536 diff --git a/libs/langchain/langchain/vectorstores/annoy.py b/libs/langchain/langchain/vectorstores/annoy.py index 975c0062145..4054285a43b 100644 --- a/libs/langchain/langchain/vectorstores/annoy.py +++ b/libs/langchain/langchain/vectorstores/annoy.py @@ -8,12 +8,12 @@ from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.base import Docstore from langchain.docstore.document import Document from langchain.docstore.in_memory import InMemoryDocstore -from 
langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance INDEX_METRICS = frozenset(["angular", "euclidean", "manhattan", "hamming", "dot"]) diff --git a/libs/langchain/langchain/vectorstores/astradb.py b/libs/langchain/langchain/vectorstores/astradb.py index d1c25e02f46..f9a428ea9c4 100644 --- a/libs/langchain/langchain/vectorstores/astradb.py +++ b/libs/langchain/langchain/vectorstores/astradb.py @@ -17,11 +17,11 @@ from typing import ( ) import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore +from langchain_core.utils.iter import batch_iterate from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore -from langchain.utils.iter import batch_iterate from langchain.vectorstores.utils import maximal_marginal_relevance ADBVST = TypeVar("ADBVST", bound="AstraDB") diff --git a/libs/langchain/langchain/vectorstores/atlas.py b/libs/langchain/langchain/vectorstores/atlas.py index 15541afccb3..230a123a7b5 100644 --- a/libs/langchain/langchain/vectorstores/atlas.py +++ b/libs/langchain/langchain/vectorstores/atlas.py @@ -5,10 +5,10 @@ import uuid from typing import Any, Iterable, List, Optional, Type import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/vectorstores/awadb.py b/libs/langchain/langchain/vectorstores/awadb.py index 258dbed75a9..79cf3d072ea 100644 --- a/libs/langchain/langchain/vectorstores/awadb.py +++ b/libs/langchain/langchain/vectorstores/awadb.py @@ -5,10 +5,10 @@ import uuid from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Type import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/azure_cosmos_db.py b/libs/langchain/langchain/vectorstores/azure_cosmos_db.py index d1816e160b8..4002eaf4db0 100644 --- a/libs/langchain/langchain/vectorstores/azure_cosmos_db.py +++ b/libs/langchain/langchain/vectorstores/azure_cosmos_db.py @@ -22,10 +22,9 @@ from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: + from langchain_core.schema.embeddings import Embeddings from pymongo.collection import Collection - from langchain.schema.embeddings import Embeddings - # Before Python 3.11 native StrEnum is not available class CosmosDBSimilarityType(str, Enum): diff --git a/libs/langchain/langchain/vectorstores/azuresearch.py b/libs/langchain/langchain/vectorstores/azuresearch.py index aa083f46856..4fbd766bc84 100644 --- a/libs/langchain/langchain/vectorstores/azuresearch.py +++ b/libs/langchain/langchain/vectorstores/azuresearch.py @@ -17,16 +17,16 @@ from typing import ( ) import numpy as np +from langchain_core.pydantic_v1 import 
root_validator +from langchain_core.schema import BaseRetriever +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) from langchain.docstore.document import Document -from langchain.pydantic_v1 import root_validator -from langchain.schema import BaseRetriever -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_env logger = logging.getLogger() diff --git a/libs/langchain/langchain/vectorstores/bageldb.py b/libs/langchain/langchain/vectorstores/bageldb.py index fbf5df78ef9..870673a12dd 100644 --- a/libs/langchain/langchain/vectorstores/bageldb.py +++ b/libs/langchain/langchain/vectorstores/bageldb.py @@ -18,10 +18,11 @@ if TYPE_CHECKING: import bagel.config from bagel.api.types import ID, OneOrMany, Where, WhereDocument +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore +from langchain_core.utils import xor_args + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore -from langchain.utils import xor_args DEFAULT_K = 5 diff --git a/libs/langchain/langchain/vectorstores/baiducloud_vector_search.py b/libs/langchain/langchain/vectorstores/baiducloud_vector_search.py index e73055d7f80..a66b1104486 100644 --- a/libs/langchain/langchain/vectorstores/baiducloud_vector_search.py +++ b/libs/langchain/langchain/vectorstores/baiducloud_vector_search.py @@ -12,9 +12,10 @@ from typing import ( Union, ) +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore if TYPE_CHECKING: from elasticsearch import Elasticsearch diff --git a/libs/langchain/langchain/vectorstores/base.py b/libs/langchain/langchain/vectorstores/base.py index 05e90ef5f04..5be4e018853 100644 --- a/libs/langchain/langchain/vectorstores/base.py +++ b/libs/langchain/langchain/vectorstores/base.py @@ -1,3 +1,3 @@ -from langchain.schema.vectorstore import VectorStore, VectorStoreRetriever +from langchain_core.schema.vectorstore import VectorStore, VectorStoreRetriever __all__ = ["VectorStore", "VectorStoreRetriever"] diff --git a/libs/langchain/langchain/vectorstores/cassandra.py b/libs/langchain/langchain/vectorstores/cassandra.py index d57c05cf86f..194a89f2783 100644 --- a/libs/langchain/langchain/vectorstores/cassandra.py +++ b/libs/langchain/langchain/vectorstores/cassandra.py @@ -20,9 +20,10 @@ import numpy as np if typing.TYPE_CHECKING: from cassandra.cluster import Session +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance CVST = TypeVar("CVST", bound="Cassandra") diff --git a/libs/langchain/langchain/vectorstores/chroma.py b/libs/langchain/langchain/vectorstores/chroma.py index 9f77447aa3f..476712b0783 100644 --- a/libs/langchain/langchain/vectorstores/chroma.py +++ b/libs/langchain/langchain/vectorstores/chroma.py 
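The vectorstore hunks from here on apply the same mechanical rewrite; condensed, the before/after for a typical module looks like the sketch below (an illustrative summary of the pattern, not a literal hunk from this patch):

# Before the split:
#   from langchain.schema.embeddings import Embeddings
#   from langchain.schema.vectorstore import VectorStore
# After the split (this patch):
from langchain_core.schema.embeddings import Embeddings
from langchain_core.schema.vectorstore import VectorStore

# Class names and APIs are unchanged; only the import root moves to langchain_core.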
@@ -16,11 +16,11 @@ from typing import ( ) import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore +from langchain_core.utils import xor_args from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore -from langchain.utils import xor_args from langchain.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/clarifai.py b/libs/langchain/langchain/vectorstores/clarifai.py index 1d92c7e5cf9..132b39c7d5d 100644 --- a/libs/langchain/langchain/vectorstores/clarifai.py +++ b/libs/langchain/langchain/vectorstores/clarifai.py @@ -7,10 +7,10 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, Iterable, List, Optional, Tuple import requests +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/vectorstores/clickhouse.py b/libs/langchain/langchain/vectorstores/clickhouse.py index e63704e09ab..d833ef42c0f 100644 --- a/libs/langchain/langchain/vectorstores/clickhouse.py +++ b/libs/langchain/langchain/vectorstores/clickhouse.py @@ -6,10 +6,11 @@ from hashlib import sha1 from threading import Thread from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from langchain_core.pydantic_v1 import BaseSettings +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.pydantic_v1 import BaseSettings -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore logger = logging.getLogger() diff --git a/libs/langchain/langchain/vectorstores/dashvector.py b/libs/langchain/langchain/vectorstores/dashvector.py index 51b8ed7ff31..e151798a15d 100644 --- a/libs/langchain/langchain/vectorstores/dashvector.py +++ b/libs/langchain/langchain/vectorstores/dashvector.py @@ -11,10 +11,10 @@ from typing import ( ) import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_env from langchain.vectorstores.utils import maximal_marginal_relevance diff --git a/libs/langchain/langchain/vectorstores/deeplake.py b/libs/langchain/langchain/vectorstores/deeplake.py index f8fa6a8fc89..d76667d6aaa 100644 --- a/libs/langchain/langchain/vectorstores/deeplake.py +++ b/libs/langchain/langchain/vectorstores/deeplake.py @@ -14,9 +14,10 @@ try: except ImportError: _DEEPLAKE_INSTALLED = False +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/vectorstores/dingo.py 
b/libs/langchain/langchain/vectorstores/dingo.py index e83a5161d39..5e78b956da5 100644 --- a/libs/langchain/langchain/vectorstores/dingo.py +++ b/libs/langchain/langchain/vectorstores/dingo.py @@ -5,10 +5,10 @@ import uuid from typing import Any, Iterable, List, Optional, Tuple import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/vectorstores/docarray/base.py b/libs/langchain/langchain/vectorstores/docarray/base.py index ffdbbbed456..6518e418b88 100644 --- a/libs/langchain/langchain/vectorstores/docarray/base.py +++ b/libs/langchain/langchain/vectorstores/docarray/base.py @@ -2,11 +2,11 @@ from abc import ABC from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type import numpy as np +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore -from langchain.pydantic_v1 import Field -from langchain.schema import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/docarray/hnsw.py b/libs/langchain/langchain/vectorstores/docarray/hnsw.py index b8d7a44475a..5fa44dd0eee 100644 --- a/libs/langchain/langchain/vectorstores/docarray/hnsw.py +++ b/libs/langchain/langchain/vectorstores/docarray/hnsw.py @@ -2,7 +2,8 @@ from __future__ import annotations from typing import Any, List, Literal, Optional -from langchain.schema.embeddings import Embeddings +from langchain_core.schema.embeddings import Embeddings + from langchain.vectorstores.docarray.base import ( DocArrayIndex, _check_docarray_import, diff --git a/libs/langchain/langchain/vectorstores/docarray/in_memory.py b/libs/langchain/langchain/vectorstores/docarray/in_memory.py index 43602b2acaa..78346abae96 100644 --- a/libs/langchain/langchain/vectorstores/docarray/in_memory.py +++ b/libs/langchain/langchain/vectorstores/docarray/in_memory.py @@ -3,7 +3,8 @@ from __future__ import annotations from typing import Any, Dict, List, Literal, Optional -from langchain.schema.embeddings import Embeddings +from langchain_core.schema.embeddings import Embeddings + from langchain.vectorstores.docarray.base import ( DocArrayIndex, _check_docarray_import, diff --git a/libs/langchain/langchain/vectorstores/elastic_vector_search.py b/libs/langchain/langchain/vectorstores/elastic_vector_search.py index 48c492c22a3..dd5ef4aef14 100644 --- a/libs/langchain/langchain/vectorstores/elastic_vector_search.py +++ b/libs/langchain/langchain/vectorstores/elastic_vector_search.py @@ -14,10 +14,11 @@ from typing import ( Union, ) -from langchain._api import deprecated +from langchain_core._api import deprecated +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_dict_or_env if 
TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/elasticsearch.py b/libs/langchain/langchain/vectorstores/elasticsearch.py index a635cdea40a..4d972381c5f 100644 --- a/libs/langchain/langchain/vectorstores/elasticsearch.py +++ b/libs/langchain/langchain/vectorstores/elasticsearch.py @@ -15,10 +15,10 @@ from typing import ( ) import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/epsilla.py b/libs/langchain/langchain/vectorstores/epsilla.py index 94521513cea..bb11305b84f 100644 --- a/libs/langchain/langchain/vectorstores/epsilla.py +++ b/libs/langchain/langchain/vectorstores/epsilla.py @@ -5,9 +5,10 @@ import logging import uuid from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore if TYPE_CHECKING: from pyepsilla import vectordb diff --git a/libs/langchain/langchain/vectorstores/faiss.py b/libs/langchain/langchain/vectorstores/faiss.py index 8c81a1789bc..39430fec178 100644 --- a/libs/langchain/langchain/vectorstores/faiss.py +++ b/libs/langchain/langchain/vectorstores/faiss.py @@ -22,12 +22,12 @@ from typing import ( ) import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.base import AddableMixin, Docstore from langchain.docstore.document import Document from langchain.docstore.in_memory import InMemoryDocstore -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/vectorstores/hippo.py b/libs/langchain/langchain/vectorstores/hippo.py index d8043a98c8e..59490e48040 100644 --- a/libs/langchain/langchain/vectorstores/hippo.py +++ b/libs/langchain/langchain/vectorstores/hippo.py @@ -3,9 +3,10 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore if TYPE_CHECKING: from transwarp_hippo_api.hippo_client import HippoClient diff --git a/libs/langchain/langchain/vectorstores/hologres.py b/libs/langchain/langchain/vectorstores/hologres.py index d925c7a4c1d..93504adfb41 100644 --- a/libs/langchain/langchain/vectorstores/hologres.py +++ b/libs/langchain/langchain/vectorstores/hologres.py @@ -5,9 +5,10 @@ import logging import uuid from typing import Any, Dict, Iterable, List, Optional, Tuple, Type +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from 
langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_dict_or_env ADA_TOKEN_COUNT = 1536 diff --git a/libs/langchain/langchain/vectorstores/lancedb.py b/libs/langchain/langchain/vectorstores/lancedb.py index c396ef648bb..4e795ef7df1 100644 --- a/libs/langchain/langchain/vectorstores/lancedb.py +++ b/libs/langchain/langchain/vectorstores/lancedb.py @@ -3,9 +3,10 @@ from __future__ import annotations import uuid from typing import Any, Iterable, List, Optional +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore class LanceDB(VectorStore): diff --git a/libs/langchain/langchain/vectorstores/llm_rails.py b/libs/langchain/langchain/vectorstores/llm_rails.py index 23ed41ad1ec..50e25f922f0 100644 --- a/libs/langchain/langchain/vectorstores/llm_rails.py +++ b/libs/langchain/langchain/vectorstores/llm_rails.py @@ -8,10 +8,10 @@ import uuid from typing import Any, Iterable, List, Optional, Tuple import requests +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import Document +from langchain_core.schema.embeddings import Embeddings -from langchain.pydantic_v1 import Field -from langchain.schema import Document -from langchain.schema.embeddings import Embeddings from langchain.vectorstores.base import VectorStore, VectorStoreRetriever diff --git a/libs/langchain/langchain/vectorstores/marqo.py b/libs/langchain/langchain/vectorstores/marqo.py index 49b1d19aa6f..1d0e9d288c3 100644 --- a/libs/langchain/langchain/vectorstores/marqo.py +++ b/libs/langchain/langchain/vectorstores/marqo.py @@ -15,9 +15,10 @@ from typing import ( Union, ) +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore if TYPE_CHECKING: import marqo diff --git a/libs/langchain/langchain/vectorstores/matching_engine.py b/libs/langchain/langchain/vectorstores/matching_engine.py index 630be56df3f..9b5d8b42704 100644 --- a/libs/langchain/langchain/vectorstores/matching_engine.py +++ b/libs/langchain/langchain/vectorstores/matching_engine.py @@ -6,9 +6,10 @@ import time import uuid from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type -from langchain.schema.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore +from langchain_core.schema.document import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.utilities.vertexai import get_client_info if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/meilisearch.py b/libs/langchain/langchain/vectorstores/meilisearch.py index e0e8c2846d8..80f24241522 100644 --- a/libs/langchain/langchain/vectorstores/meilisearch.py +++ b/libs/langchain/langchain/vectorstores/meilisearch.py @@ -3,9 +3,10 @@ from __future__ import annotations import uuid from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from 
langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_env if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/milvus.py b/libs/langchain/langchain/vectorstores/milvus.py index fc10852e4de..83556848341 100644 --- a/libs/langchain/langchain/vectorstores/milvus.py +++ b/libs/langchain/langchain/vectorstores/milvus.py @@ -5,10 +5,10 @@ from typing import Any, Iterable, List, Optional, Tuple, Union from uuid import uuid4 import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/vectorstores/momento_vector_index.py b/libs/langchain/langchain/vectorstores/momento_vector_index.py index e79d48c46f6..db30d9251c7 100644 --- a/libs/langchain/langchain/vectorstores/momento_vector_index.py +++ b/libs/langchain/langchain/vectorstores/momento_vector_index.py @@ -11,9 +11,10 @@ from typing import ( ) from uuid import uuid4 +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_env from langchain.vectorstores.utils import DistanceStrategy diff --git a/libs/langchain/langchain/vectorstores/mongodb_atlas.py b/libs/langchain/langchain/vectorstores/mongodb_atlas.py index b7caec575fd..872fa7e5757 100644 --- a/libs/langchain/langchain/vectorstores/mongodb_atlas.py +++ b/libs/langchain/langchain/vectorstores/mongodb_atlas.py @@ -15,10 +15,10 @@ from typing import ( ) import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/myscale.py b/libs/langchain/langchain/vectorstores/myscale.py index 9dbc6ae40a9..c57ca4c2a5c 100644 --- a/libs/langchain/langchain/vectorstores/myscale.py +++ b/libs/langchain/langchain/vectorstores/myscale.py @@ -6,10 +6,11 @@ from hashlib import sha1 from threading import Thread from typing import Any, Dict, Iterable, List, Optional, Tuple +from langchain_core.pydantic_v1 import BaseSettings +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.pydantic_v1 import BaseSettings -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore logger = logging.getLogger() diff --git a/libs/langchain/langchain/vectorstores/neo4j_vector.py b/libs/langchain/langchain/vectorstores/neo4j_vector.py index e826a502212..595656083cd 100644 --- a/libs/langchain/langchain/vectorstores/neo4j_vector.py +++ b/libs/langchain/langchain/vectorstores/neo4j_vector.py @@ -15,9 +15,10 @@ from typing import ( Type, ) +from 
langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_env from langchain.vectorstores.utils import DistanceStrategy diff --git a/libs/langchain/langchain/vectorstores/nucliadb.py b/libs/langchain/langchain/vectorstores/nucliadb.py index de4537ca4ec..0a649c7114a 100644 --- a/libs/langchain/langchain/vectorstores/nucliadb.py +++ b/libs/langchain/langchain/vectorstores/nucliadb.py @@ -1,9 +1,9 @@ import os from typing import Any, Dict, Iterable, List, Optional, Type -from langchain.schema.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VST, VectorStore +from langchain_core.schema.document import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VST, VectorStore FIELD_TYPES = { "f": "files", diff --git a/libs/langchain/langchain/vectorstores/opensearch_vector_search.py b/libs/langchain/langchain/vectorstores/opensearch_vector_search.py index 19241e476e6..b2cc55c7e8c 100644 --- a/libs/langchain/langchain/vectorstores/opensearch_vector_search.py +++ b/libs/langchain/langchain/vectorstores/opensearch_vector_search.py @@ -5,10 +5,10 @@ import warnings from typing import Any, Dict, Iterable, List, Optional, Tuple import numpy as np +from langchain_core.schema import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore -from langchain.schema import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_dict_or_env from langchain.vectorstores.utils import maximal_marginal_relevance diff --git a/libs/langchain/langchain/vectorstores/pgembedding.py b/libs/langchain/langchain/vectorstores/pgembedding.py index 0d4562f96ad..db2f67b927a 100644 --- a/libs/langchain/langchain/vectorstores/pgembedding.py +++ b/libs/langchain/langchain/vectorstores/pgembedding.py @@ -14,9 +14,10 @@ try: except ImportError: from sqlalchemy.ext.declarative import declarative_base +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_dict_or_env Base = declarative_base() # type: Any @@ -401,7 +402,7 @@ class PGEmbedding(VectorStore): page_content=result.EmbeddingStore.document, metadata=result.EmbeddingStore.cmetadata, ), - result.distance if self.embedding_function is not None else None, + result.distance if self.embedding_function is not None else 0.0, ) for result in results ] diff --git a/libs/langchain/langchain/vectorstores/pgvecto_rs.py b/libs/langchain/langchain/vectorstores/pgvecto_rs.py index 2471395295e..ff3c01477cf 100644 --- a/libs/langchain/langchain/vectorstores/pgvecto_rs.py +++ b/libs/langchain/langchain/vectorstores/pgvecto_rs.py @@ -5,15 +5,14 @@ from typing import Any, Iterable, List, Literal, Optional, Tuple, Type import numpy as np import sqlalchemy +from langchain_core.schema import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from 
sqlalchemy import insert, select from sqlalchemy.dialects import postgresql from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm.session import Session -from langchain.schema import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore - class _ORMBase(DeclarativeBase): __tablename__: str diff --git a/libs/langchain/langchain/vectorstores/pgvector.py b/libs/langchain/langchain/vectorstores/pgvector.py index 41deb51f10a..db8122e0f67 100644 --- a/libs/langchain/langchain/vectorstores/pgvector.py +++ b/libs/langchain/langchain/vectorstores/pgvector.py @@ -30,9 +30,10 @@ try: except ImportError: from sqlalchemy.ext.declarative import declarative_base +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_dict_or_env from langchain.vectorstores.utils import maximal_marginal_relevance diff --git a/libs/langchain/langchain/vectorstores/pinecone.py b/libs/langchain/langchain/vectorstores/pinecone.py index b39ab9c0814..e7337cf726e 100644 --- a/libs/langchain/langchain/vectorstores/pinecone.py +++ b/libs/langchain/langchain/vectorstores/pinecone.py @@ -6,11 +6,11 @@ import warnings from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Union import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore +from langchain_core.utils.iter import batch_iterate from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore -from langchain.utils.iter import batch_iterate from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/qdrant.py b/libs/langchain/langchain/vectorstores/qdrant.py index b85edabcdc2..97f0e340c9f 100644 --- a/libs/langchain/langchain/vectorstores/qdrant.py +++ b/libs/langchain/langchain/vectorstores/qdrant.py @@ -23,10 +23,10 @@ from typing import ( ) import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/redis/base.py b/libs/langchain/langchain/vectorstores/redis/base.py index 940b3c9ce1a..13156ccc138 100644 --- a/libs/langchain/langchain/vectorstores/redis/base.py +++ b/libs/langchain/langchain/vectorstores/redis/base.py @@ -22,12 +22,12 @@ from typing import ( import numpy as np import yaml +from langchain_core._api import deprecated +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore, VectorStoreRetriever -from langchain._api import deprecated from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore, VectorStoreRetriever from langchain.utilities.redis import ( 
_array_to_buffer, _buffer_to_array, diff --git a/libs/langchain/langchain/vectorstores/redis/schema.py b/libs/langchain/langchain/vectorstores/redis/schema.py index 5419e7ba991..e175b8111cf 100644 --- a/libs/langchain/langchain/vectorstores/redis/schema.py +++ b/libs/langchain/langchain/vectorstores/redis/schema.py @@ -7,9 +7,9 @@ from typing import Any, Dict, List, Optional, Union import numpy as np import yaml +from langchain_core.pydantic_v1 import BaseModel, Field, validator from typing_extensions import TYPE_CHECKING, Literal -from langchain.pydantic_v1 import BaseModel, Field, validator from langchain.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/rocksetdb.py b/libs/langchain/langchain/vectorstores/rocksetdb.py index a5fb5bda671..cae5b6de2d6 100644 --- a/libs/langchain/langchain/vectorstores/rocksetdb.py +++ b/libs/langchain/langchain/vectorstores/rocksetdb.py @@ -4,9 +4,10 @@ import logging from enum import Enum from typing import Any, Iterable, List, Optional, Tuple +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/vectorstores/scann.py b/libs/langchain/langchain/vectorstores/scann.py index 9e730314085..999b1c992bf 100644 --- a/libs/langchain/langchain/vectorstores/scann.py +++ b/libs/langchain/langchain/vectorstores/scann.py @@ -7,12 +7,12 @@ from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.base import AddableMixin, Docstore from langchain.docstore.document import Document from langchain.docstore.in_memory import InMemoryDocstore -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import DistanceStrategy diff --git a/libs/langchain/langchain/vectorstores/semadb.py b/libs/langchain/langchain/vectorstores/semadb.py index dd89138c2ef..c7aeab3150c 100644 --- a/libs/langchain/langchain/vectorstores/semadb.py +++ b/libs/langchain/langchain/vectorstores/semadb.py @@ -3,10 +3,10 @@ from uuid import uuid4 import numpy as np import requests +from langchain_core.schema.document import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore -from langchain.schema.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_env from langchain.vectorstores.utils import DistanceStrategy diff --git a/libs/langchain/langchain/vectorstores/singlestoredb.py b/libs/langchain/langchain/vectorstores/singlestoredb.py index 070eeb7b0c0..27f06533031 100644 --- a/libs/langchain/langchain/vectorstores/singlestoredb.py +++ b/libs/langchain/langchain/vectorstores/singlestoredb.py @@ -12,11 +12,11 @@ from typing import ( Type, ) +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore, VectorStoreRetriever from sqlalchemy.pool import QueuePool from langchain.docstore.document import Document -from 
langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore, VectorStoreRetriever from langchain.vectorstores.utils import DistanceStrategy DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.DOT_PRODUCT diff --git a/libs/langchain/langchain/vectorstores/sklearn.py b/libs/langchain/langchain/vectorstores/sklearn.py index 32224e9b10d..34aea3ff5bb 100644 --- a/libs/langchain/langchain/vectorstores/sklearn.py +++ b/libs/langchain/langchain/vectorstores/sklearn.py @@ -10,10 +10,11 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Type from uuid import uuid4 +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore +from langchain_core.utils import guard_import + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore -from langchain.utils import guard_import from langchain.vectorstores.utils import maximal_marginal_relevance DEFAULT_K = 4 # Number of Documents to return. diff --git a/libs/langchain/langchain/vectorstores/sqlitevss.py b/libs/langchain/langchain/vectorstores/sqlitevss.py index fcc4157b261..0216c279d4e 100644 --- a/libs/langchain/langchain/vectorstores/sqlitevss.py +++ b/libs/langchain/langchain/vectorstores/sqlitevss.py @@ -13,9 +13,10 @@ from typing import ( Type, ) +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore if TYPE_CHECKING: import sqlite3 diff --git a/libs/langchain/langchain/vectorstores/starrocks.py b/libs/langchain/langchain/vectorstores/starrocks.py index 04dbec591ee..4bd64e4f457 100644 --- a/libs/langchain/langchain/vectorstores/starrocks.py +++ b/libs/langchain/langchain/vectorstores/starrocks.py @@ -6,10 +6,11 @@ from hashlib import sha1 from threading import Thread from typing import Any, Dict, Iterable, List, Optional, Tuple +from langchain_core.pydantic_v1 import BaseSettings +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.pydantic_v1 import BaseSettings -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore logger = logging.getLogger() DEBUG = False diff --git a/libs/langchain/langchain/vectorstores/supabase.py b/libs/langchain/langchain/vectorstores/supabase.py index f5dbdad9af8..d37880eac01 100644 --- a/libs/langchain/langchain/vectorstores/supabase.py +++ b/libs/langchain/langchain/vectorstores/supabase.py @@ -15,10 +15,10 @@ from typing import ( ) import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: @@ -46,7 +46,7 @@ class SupabaseVectorStore(VectorStore): .. 
code-block:: python from langchain.embeddings.openai import OpenAIEmbeddings - from langchain.schema import Document + from langchain_core.schema import Document from langchain.vectorstores import SupabaseVectorStore from supabase.client import create_client diff --git a/libs/langchain/langchain/vectorstores/tair.py b/libs/langchain/langchain/vectorstores/tair.py index 0e9d29023aa..75a86ec8e6c 100644 --- a/libs/langchain/langchain/vectorstores/tair.py +++ b/libs/langchain/langchain/vectorstores/tair.py @@ -5,9 +5,10 @@ import logging import uuid from typing import Any, Iterable, List, Optional, Type +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/vectorstores/tencentvectordb.py b/libs/langchain/langchain/vectorstores/tencentvectordb.py index fe77390c72a..cc5cc947169 100644 --- a/libs/langchain/langchain/vectorstores/tencentvectordb.py +++ b/libs/langchain/langchain/vectorstores/tencentvectordb.py @@ -7,11 +7,11 @@ import time from typing import Any, Dict, Iterable, List, Optional, Tuple import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore +from langchain_core.utils import guard_import from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore -from langchain.utils import guard_import from langchain.vectorstores.utils import maximal_marginal_relevance logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/vectorstores/tigris.py b/libs/langchain/langchain/vectorstores/tigris.py index 7f9eee0b579..e168909c07c 100644 --- a/libs/langchain/langchain/vectorstores/tigris.py +++ b/libs/langchain/langchain/vectorstores/tigris.py @@ -3,9 +3,9 @@ from __future__ import annotations import itertools from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple -from langchain.schema import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore +from langchain_core.schema import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore if TYPE_CHECKING: from tigrisdb import TigrisClient diff --git a/libs/langchain/langchain/vectorstores/tiledb.py b/libs/langchain/langchain/vectorstores/tiledb.py index 11094ec4b83..8b144265b3e 100644 --- a/libs/langchain/langchain/vectorstores/tiledb.py +++ b/libs/langchain/langchain/vectorstores/tiledb.py @@ -7,10 +7,10 @@ import sys from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance INDEX_METRICS = frozenset(["euclidean"]) diff --git a/libs/langchain/langchain/vectorstores/timescalevector.py b/libs/langchain/langchain/vectorstores/timescalevector.py index 01e2590b5ec..b7fcbe77ca8 100644 --- 
a/libs/langchain/langchain/vectorstores/timescalevector.py +++ b/libs/langchain/langchain/vectorstores/timescalevector.py @@ -18,9 +18,10 @@ from typing import ( Union, ) -from langchain.schema.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore +from langchain_core.schema.document import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.utils import get_from_dict_or_env from langchain.vectorstores.utils import DistanceStrategy diff --git a/libs/langchain/langchain/vectorstores/typesense.py b/libs/langchain/langchain/vectorstores/typesense.py index 622cdb5a4eb..bb37192670f 100644 --- a/libs/langchain/langchain/vectorstores/typesense.py +++ b/libs/langchain/langchain/vectorstores/typesense.py @@ -3,9 +3,10 @@ from __future__ import annotations import uuid from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_env if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/usearch.py b/libs/langchain/langchain/vectorstores/usearch.py index f44fc31ba60..4c5a23533d5 100644 --- a/libs/langchain/langchain/vectorstores/usearch.py +++ b/libs/langchain/langchain/vectorstores/usearch.py @@ -3,12 +3,12 @@ from __future__ import annotations from typing import Any, Dict, Iterable, List, Optional, Tuple import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.base import AddableMixin, Docstore from langchain.docstore.document import Document from langchain.docstore.in_memory import InMemoryDocstore -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore def dependable_usearch_import() -> Any: diff --git a/libs/langchain/langchain/vectorstores/vald.py b/libs/langchain/langchain/vectorstores/vald.py index 560515a5272..d7ca76a004e 100644 --- a/libs/langchain/langchain/vectorstores/vald.py +++ b/libs/langchain/langchain/vectorstores/vald.py @@ -4,10 +4,10 @@ from __future__ import annotations from typing import Any, Iterable, List, Optional, Tuple, Type import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance diff --git a/libs/langchain/langchain/vectorstores/vearch.py b/libs/langchain/langchain/vectorstores/vearch.py index 11cede24e10..67b2a1e84d8 100644 --- a/libs/langchain/langchain/vectorstores/vearch.py +++ b/libs/langchain/langchain/vectorstores/vearch.py @@ -6,10 +6,10 @@ import uuid from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore if 
TYPE_CHECKING: import vearch diff --git a/libs/langchain/langchain/vectorstores/vectara.py b/libs/langchain/langchain/vectorstores/vectara.py index 221d36716a8..85f92d9fe6a 100644 --- a/libs/langchain/langchain/vectorstores/vectara.py +++ b/libs/langchain/langchain/vectorstores/vectara.py @@ -7,11 +7,10 @@ from hashlib import md5 from typing import Any, Iterable, List, Optional, Tuple, Type import requests - -from langchain.pydantic_v1 import Field -from langchain.schema import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore, VectorStoreRetriever +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore, VectorStoreRetriever logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/vectorstores/vespa.py b/libs/langchain/langchain/vectorstores/vespa.py index ae0a7ce7647..31e6cf5c60e 100644 --- a/libs/langchain/langchain/vectorstores/vespa.py +++ b/libs/langchain/langchain/vectorstores/vespa.py @@ -2,8 +2,9 @@ from __future__ import annotations from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union +from langchain_core.schema.embeddings import Embeddings + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings from langchain.vectorstores.base import VectorStore, VectorStoreRetriever diff --git a/libs/langchain/langchain/vectorstores/weaviate.py b/libs/langchain/langchain/vectorstores/weaviate.py index 341a3a4d560..9d33dd5df4c 100644 --- a/libs/langchain/langchain/vectorstores/weaviate.py +++ b/libs/langchain/langchain/vectorstores/weaviate.py @@ -15,10 +15,10 @@ from typing import ( from uuid import uuid4 import numpy as np +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: diff --git a/libs/langchain/langchain/vectorstores/xata.py b/libs/langchain/langchain/vectorstores/xata.py index d031a21e792..eec1b22dec3 100644 --- a/libs/langchain/langchain/vectorstores/xata.py +++ b/libs/langchain/langchain/vectorstores/xata.py @@ -4,9 +4,10 @@ import time from itertools import repeat from typing import Any, Dict, Iterable, List, Optional, Tuple, Type +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore class XataVectorStore(VectorStore): diff --git a/libs/langchain/langchain/vectorstores/zep.py b/libs/langchain/langchain/vectorstores/zep.py index 7872b1f9254..395c6670b1e 100644 --- a/libs/langchain/langchain/vectorstores/zep.py +++ b/libs/langchain/langchain/vectorstores/zep.py @@ -5,9 +5,10 @@ import warnings from dataclasses import asdict, dataclass from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore + from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import 
VectorStore if TYPE_CHECKING: from zep_python.document import Document as ZepDocument diff --git a/libs/langchain/langchain/vectorstores/zilliz.py b/libs/langchain/langchain/vectorstores/zilliz.py index 95757d20a67..7d45cf88436 100644 --- a/libs/langchain/langchain/vectorstores/zilliz.py +++ b/libs/langchain/langchain/vectorstores/zilliz.py @@ -3,7 +3,8 @@ from __future__ import annotations import logging from typing import Any, Dict, List, Optional -from langchain.schema.embeddings import Embeddings +from langchain_core.schema.embeddings import Embeddings + from langchain.vectorstores.milvus import Milvus logger = logging.getLogger(__name__) diff --git a/libs/langchain/scripts/check_imports.sh b/libs/langchain/scripts/check_imports.sh index 80c38b30f58..c3463691a61 100755 --- a/libs/langchain/scripts/check_imports.sh +++ b/libs/langchain/scripts/check_imports.sh @@ -7,22 +7,22 @@ errors=0 # Check the conditions git grep '^from langchain import' langchain | grep -vE 'from langchain import (__version__|hub)' && errors=$((errors+1)) -git grep '^from langchain ' langchain/pydantic_v1 | grep -vE 'from langchain.(pydantic_v1)' && errors=$((errors+1)) -git grep '^from langchain' langchain/load | grep -vE 'from langchain.(pydantic_v1|load)' && errors=$((errors+1)) -git grep '^from langchain' langchain/utils | grep -vE 'from langchain.(pydantic_v1|utils)' && errors=$((errors+1)) -git grep '^from langchain' langchain/schema | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|env)' && errors=$((errors+1)) -git grep '^from langchain' langchain/adapters | grep -vE 'from langchain.(pydantic_v1|utils|schema|load)' && errors=$((errors+1)) -git grep '^from langchain' langchain/callbacks | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/pydantic_v1 | grep -vE 'from langchain.(pydantic_v1)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/load | grep -vE 'from langchain.(pydantic_v1|load)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/utils | grep -vE 'from langchain.(pydantic_v1|utils)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/schema | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|env)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/adapters | grep -vE 'from langchain.(pydantic_v1|utils|schema|load)' && errors=$((errors+1)) +git grep '^from langchain\.' 
langchain/callbacks | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env)' && errors=$((errors+1)) # TODO: it's probably not amazing so that so many other modules depend on `langchain.utilities`, because there can be a lot of imports there -git grep '^from langchain' langchain/utilities | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|utilities)' && errors=$((errors+1)) -git grep '^from langchain' langchain/storage | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|utilities)' && errors=$((errors+1)) -git grep '^from langchain' langchain/prompts | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api)' && errors=$((errors+1)) -git grep '^from langchain' langchain/output_parsers | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|output_parsers)' && errors=$((errors+1)) -git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities|globals)' && errors=$((errors+1)) -git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities|globals)' && errors=$((errors+1)) -git grep '^from langchain' langchain/embeddings | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|llms|embeddings|utilities)' && errors=$((errors+1)) -git grep '^from langchain' langchain/docstore | grep -vE 'from langchain.(pydantic_v1|utils|schema|docstore)' && errors=$((errors+1)) -git grep '^from langchain' langchain/vectorstores | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|_api|storage|llms|docstore|vectorstores|utilities)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/utilities | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|utilities)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/storage | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|utilities)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/prompts | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/output_parsers | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|output_parsers)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities|globals)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities|globals)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/embeddings | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|llms|embeddings|utilities)' && errors=$((errors+1)) +git grep '^from langchain\.' langchain/docstore | grep -vE 'from langchain.(pydantic_v1|utils|schema|docstore)' && errors=$((errors+1)) +git grep '^from langchain\.' 
langchain/vectorstores | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|_api|storage|llms|docstore|vectorstores|utilities)' && errors=$((errors+1)) # Decide on an exit status based on the errors if [ "$errors" -gt 0 ]; then diff --git a/libs/langchain/scripts/check_pydantic.sh b/libs/langchain/scripts/check_pydantic.sh index 7c2d9c5c0a3..06b5bb81ae2 100755 --- a/libs/langchain/scripts/check_pydantic.sh +++ b/libs/langchain/scripts/check_pydantic.sh @@ -20,8 +20,8 @@ result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') if [ -n "$result" ]; then echo "ERROR: The following lines need to be updated:" echo "$result" - echo "Please replace the code with an import from langchain.pydantic_v1." + echo "Please replace the code with an import from langchain_core.pydantic_v1." echo "For example, replace 'from pydantic import BaseModel'" - echo "with 'from langchain.pydantic_v1 import BaseModel'" + echo "with 'from langchain_core.pydantic_v1 import BaseModel'" exit 1 fi diff --git a/libs/langchain/tests/integration_tests/cache/test_cassandra.py b/libs/langchain/tests/integration_tests/cache/test_cassandra.py index 9b2186c13b1..60700c73290 100644 --- a/libs/langchain/tests/integration_tests/cache/test_cassandra.py +++ b/libs/langchain/tests/integration_tests/cache/test_cassandra.py @@ -4,10 +4,10 @@ import time from typing import Any, Iterator, Tuple import pytest +from langchain_core.schema import Generation, LLMResult from langchain.cache import CassandraCache, CassandraSemanticCache from langchain.globals import get_llm_cache, set_llm_cache -from langchain.schema import Generation, LLMResult from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/integration_tests/cache/test_gptcache.py b/libs/langchain/tests/integration_tests/cache/test_gptcache.py index 0bdd5b51096..12d0430807c 100644 --- a/libs/langchain/tests/integration_tests/cache/test_gptcache.py +++ b/libs/langchain/tests/integration_tests/cache/test_gptcache.py @@ -2,10 +2,10 @@ import os from typing import Any, Callable, Union import pytest +from langchain_core.schema import Generation from langchain.cache import GPTCache from langchain.globals import get_llm_cache, set_llm_cache -from langchain.schema import Generation from tests.unit_tests.llms.fake_llm import FakeLLM try: diff --git a/libs/langchain/tests/integration_tests/cache/test_momento_cache.py b/libs/langchain/tests/integration_tests/cache/test_momento_cache.py index a4cc7fb21d4..ca551b62548 100644 --- a/libs/langchain/tests/integration_tests/cache/test_momento_cache.py +++ b/libs/langchain/tests/integration_tests/cache/test_momento_cache.py @@ -11,10 +11,10 @@ from datetime import timedelta from typing import Iterator import pytest +from langchain_core.schema import Generation, LLMResult from langchain.cache import MomentoCache from langchain.globals import set_llm_cache -from langchain.schema import Generation, LLMResult from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py index 8005d68d8d0..a02670d3b6f 100644 --- a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py +++ b/libs/langchain/tests/integration_tests/cache/test_redis_cache.py @@ -3,14 +3,14 @@ import uuid from typing import List, cast import pytest +from langchain_core.load.dump import dumps +from 
langchain_core.schema import Generation, LLMResult +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.schema.output import ChatGeneration from langchain.cache import RedisCache, RedisSemanticCache from langchain.globals import get_llm_cache, set_llm_cache -from langchain.load.dump import dumps -from langchain.schema import Generation, LLMResult -from langchain.schema.embeddings import Embeddings -from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage -from langchain.schema.output import ChatGeneration from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, FakeEmbeddings, diff --git a/libs/langchain/tests/integration_tests/cache/test_upstash_redis_cache.py b/libs/langchain/tests/integration_tests/cache/test_upstash_redis_cache.py index 1a78d0ae85f..1ed5ba98b2a 100644 --- a/libs/langchain/tests/integration_tests/cache/test_upstash_redis_cache.py +++ b/libs/langchain/tests/integration_tests/cache/test_upstash_redis_cache.py @@ -2,10 +2,10 @@ import uuid import pytest +from langchain_core.schema import Generation, LLMResult import langchain from langchain.cache import UpstashRedisCache -from langchain.schema import Generation, LLMResult from tests.unit_tests.llms.fake_chat_model import FakeChatModel from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/integration_tests/callbacks/test_langchain_tracer.py b/libs/langchain/tests/integration_tests/callbacks/test_langchain_tracer.py index e84eae5aa57..4a26ff733d0 100644 --- a/libs/langchain/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/libs/langchain/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -4,6 +4,7 @@ import os import pytest from aiohttp import ClientSession +from langchain_core.prompts import PromptTemplate from langchain.agents import AgentType, initialize_agent, load_tools from langchain.callbacks import tracing_enabled @@ -17,7 +18,6 @@ from langchain.chains.constitutional_ai.base import ConstitutionalChain from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chat_models import ChatOpenAI from langchain.llms import OpenAI -from langchain.prompts import PromptTemplate questions = [ ( diff --git a/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py b/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py index 906e12c7041..96f70e67d0f 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py @@ -2,13 +2,13 @@ from typing import List import pytest +from langchain_core.schema import ChatGeneration, LLMResult +from langchain_core.schema.messages import AIMessage, BaseMessage, HumanMessage from langchain.callbacks.manager import CallbackManager from langchain.chat_models.anthropic import ( ChatAnthropic, ) -from langchain.schema import ChatGeneration, LLMResult -from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/integration_tests/chat_models/test_azure_openai.py b/libs/langchain/tests/integration_tests/chat_models/test_azure_openai.py index 137c35342be..8f5565483c9 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_azure_openai.py +++ 
b/libs/langchain/tests/integration_tests/chat_models/test_azure_openai.py @@ -3,15 +3,15 @@ import os from typing import Any import pytest - -from langchain.callbacks.manager import CallbackManager -from langchain.chat_models import AzureChatOpenAI -from langchain.schema import ( +from langchain_core.schema import ( ChatGeneration, ChatResult, LLMResult, ) -from langchain.schema.messages import BaseMessage, HumanMessage +from langchain_core.schema.messages import BaseMessage, HumanMessage + +from langchain.callbacks.manager import CallbackManager +from langchain.chat_models import AzureChatOpenAI from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "") diff --git a/libs/langchain/tests/integration_tests/chat_models/test_azureml_endpoint.py b/libs/langchain/tests/integration_tests/chat_models/test_azureml_endpoint.py index 5b1e351aa74..8050eb3dc0d 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_azureml_endpoint.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_azureml_endpoint.py @@ -1,10 +1,6 @@ """Test AzureML Chat Endpoint wrapper.""" -from langchain.chat_models.azureml_endpoint import ( - AzureMLChatOnlineEndpoint, - LlamaContentFormatter, -) -from langchain.schema import ( +from langchain_core.schema import ( AIMessage, BaseMessage, ChatGeneration, @@ -12,6 +8,11 @@ from langchain.schema import ( LLMResult, ) +from langchain.chat_models.azureml_endpoint import ( + AzureMLChatOnlineEndpoint, + LlamaContentFormatter, +) + def test_llama_call() -> None: """Test valid call to Open Source Foundation Model.""" diff --git a/libs/langchain/tests/integration_tests/chat_models/test_baichuan.py b/libs/langchain/tests/integration_tests/chat_models/test_baichuan.py index 58b5dd5aa28..0dfbd1dabc9 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_baichuan.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_baichuan.py @@ -1,5 +1,6 @@ +from langchain_core.schema.messages import AIMessage, HumanMessage + from langchain.chat_models.baichuan import ChatBaichuan -from langchain.schema.messages import AIMessage, HumanMessage def test_chat_baichuan() -> None: diff --git a/libs/langchain/tests/integration_tests/chat_models/test_bedrock.py b/libs/langchain/tests/integration_tests/chat_models/test_bedrock.py index 1c93efa3b94..e85fb3815fb 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_bedrock.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_bedrock.py @@ -2,11 +2,11 @@ from typing import Any import pytest +from langchain_core.schema import ChatGeneration, LLMResult +from langchain_core.schema.messages import BaseMessage, HumanMessage, SystemMessage from langchain.callbacks.manager import CallbackManager from langchain.chat_models import BedrockChat -from langchain.schema import ChatGeneration, LLMResult -from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/integration_tests/chat_models/test_ernie.py b/libs/langchain/tests/integration_tests/chat_models/test_ernie.py index 4b79f40d093..214f55922e0 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_ernie.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_ernie.py @@ -1,7 +1,7 @@ import pytest +from langchain_core.schema.messages import AIMessage, HumanMessage from langchain.chat_models.ernie import 
ErnieBotChat -from langchain.schema.messages import AIMessage, HumanMessage def test_chat_ernie_bot() -> None: diff --git a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py index 43657cdae3c..b8d96f004a9 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py @@ -3,10 +3,10 @@ import sys from typing import cast import pytest +from langchain_core.schema import ChatGeneration, ChatResult, LLMResult +from langchain_core.schema.messages import BaseMessage, HumanMessage, SystemMessage from langchain.chat_models.fireworks import ChatFireworks -from langchain.schema import ChatGeneration, ChatResult, LLMResult -from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage if sys.version_info < (3, 9): pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True) diff --git a/libs/langchain/tests/integration_tests/chat_models/test_google_palm.py b/libs/langchain/tests/integration_tests/chat_models/test_google_palm.py index 09ded60064d..efb4320bc29 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_google_palm.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_google_palm.py @@ -5,14 +5,14 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to """ import pytest - -from langchain.chat_models import ChatGooglePalm -from langchain.schema import ( +from langchain_core.schema import ( ChatGeneration, ChatResult, LLMResult, ) -from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.schema.messages import BaseMessage, HumanMessage, SystemMessage + +from langchain.chat_models import ChatGooglePalm def test_chat_google_palm() -> None: diff --git a/libs/langchain/tests/integration_tests/chat_models/test_hunyuan.py b/libs/langchain/tests/integration_tests/chat_models/test_hunyuan.py index 4024994584c..59b0cc6362f 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_hunyuan.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_hunyuan.py @@ -1,5 +1,6 @@ +from langchain_core.schema.messages import AIMessage, HumanMessage + from langchain.chat_models.hunyuan import ChatHunyuan -from langchain.schema.messages import AIMessage, HumanMessage def test_chat_hunyuan() -> None: diff --git a/libs/langchain/tests/integration_tests/chat_models/test_jinachat.py b/libs/langchain/tests/integration_tests/chat_models/test_jinachat.py index f4100a21381..fee3b53cadb 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_jinachat.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_jinachat.py @@ -2,16 +2,16 @@ import pytest - -from langchain.callbacks.manager import CallbackManager -from langchain.chat_models.jinachat import JinaChat -from langchain.schema import ( +from langchain_core.schema import ( BaseMessage, ChatGeneration, HumanMessage, LLMResult, SystemMessage, ) + +from langchain.callbacks.manager import CallbackManager +from langchain.chat_models.jinachat import JinaChat from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/integration_tests/chat_models/test_konko.py b/libs/langchain/tests/integration_tests/chat_models/test_konko.py index c47bbbb3f0c..ff49751d970 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_konko.py +++ 
b/libs/langchain/tests/integration_tests/chat_models/test_konko.py @@ -2,15 +2,15 @@ from typing import Any import pytest - -from langchain.callbacks.manager import CallbackManager -from langchain.chat_models.konko import ChatKonko -from langchain.schema import ( +from langchain_core.schema import ( ChatGeneration, ChatResult, LLMResult, ) -from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.schema.messages import BaseMessage, HumanMessage, SystemMessage + +from langchain.callbacks.manager import CallbackManager +from langchain.chat_models.konko import ChatKonko from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/integration_tests/chat_models/test_litellm.py b/libs/langchain/tests/integration_tests/chat_models/test_litellm.py index 4f252453da6..ba380a90b81 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_litellm.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_litellm.py @@ -1,15 +1,16 @@ """Test Anthropic API wrapper.""" from typing import List +from langchain_core.schema import ( + ChatGeneration, + LLMResult, +) +from langchain_core.schema.messages import AIMessage, BaseMessage, HumanMessage + from langchain.callbacks.manager import ( CallbackManager, ) from langchain.chat_models.litellm import ChatLiteLLM -from langchain.schema import ( - ChatGeneration, - LLMResult, -) -from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/integration_tests/chat_models/test_openai.py b/libs/langchain/tests/integration_tests/chat_models/test_openai.py index ae1d3acab85..47c9a8e2d5c 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_openai.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_openai.py @@ -2,6 +2,15 @@ from typing import Any, List, Optional, Union import pytest +from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import ( + ChatGeneration, + ChatResult, + LLMResult, +) +from langchain_core.schema.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk from langchain.callbacks.base import AsyncCallbackHandler from langchain.callbacks.manager import CallbackManager @@ -10,15 +19,6 @@ from langchain.chains.openai_functions import ( ) from langchain.chat_models.openai import ChatOpenAI from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser -from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema import ( - ChatGeneration, - ChatResult, - LLMResult, -) -from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage -from langchain.schema.output import ChatGenerationChunk, GenerationChunk from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/integration_tests/chat_models/test_pai_eas_chat_endpoint.py b/libs/langchain/tests/integration_tests/chat_models/test_pai_eas_chat_endpoint.py index 93fb7b3b31f..a2f519ce88a 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_pai_eas_chat_endpoint.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_pai_eas_chat_endpoint.py @@ -1,15 
+1,16 @@ """Test AliCloud Pai Eas Chat Model.""" import os -from langchain.callbacks.manager import CallbackManager -from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint -from langchain.schema import ( +from langchain_core.schema import ( AIMessage, BaseMessage, ChatGeneration, HumanMessage, LLMResult, ) + +from langchain.callbacks.manager import CallbackManager +from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/integration_tests/chat_models/test_promptlayer_openai.py b/libs/langchain/tests/integration_tests/chat_models/test_promptlayer_openai.py index 3e5e9f8850a..6e311482637 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_promptlayer_openai.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_promptlayer_openai.py @@ -1,15 +1,15 @@ """Test PromptLayerChatOpenAI wrapper.""" import pytest - -from langchain.callbacks.manager import CallbackManager -from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI -from langchain.schema import ( +from langchain_core.schema import ( ChatGeneration, ChatResult, LLMResult, ) -from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.schema.messages import BaseMessage, HumanMessage, SystemMessage + +from langchain.callbacks.manager import CallbackManager +from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/integration_tests/chat_models/test_qianfan_endpoint.py b/libs/langchain/tests/integration_tests/chat_models/test_qianfan_endpoint.py index 3c82ed10ef1..8a4547ff6ae 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_qianfan_endpoint.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_qianfan_endpoint.py @@ -2,13 +2,8 @@ from typing import Any -from langchain.callbacks.manager import CallbackManager -from langchain.chains.openai_functions import ( - create_openai_fn_chain, -) -from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint -from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain.schema import ( +from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain_core.schema import ( AIMessage, BaseMessage, ChatGeneration, @@ -16,6 +11,12 @@ from langchain.schema import ( HumanMessage, LLMResult, ) + +from langchain.callbacks.manager import CallbackManager +from langchain.chains.openai_functions import ( + create_openai_fn_chain, +) +from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler _FUNCTIONS: Any = [ diff --git a/libs/langchain/tests/integration_tests/chat_models/test_tongyi.py b/libs/langchain/tests/integration_tests/chat_models/test_tongyi.py index afb00e1d91d..42519c248db 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_tongyi.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_tongyi.py @@ -1,14 +1,15 @@ """Test Alibaba Tongyi Chat Model.""" -from langchain.callbacks.manager import CallbackManager -from langchain.chat_models.tongyi import ChatTongyi -from langchain.schema import ( +from langchain_core.schema import ( AIMessage, BaseMessage, ChatGeneration, HumanMessage, LLMResult, ) + +from 
langchain.callbacks.manager import CallbackManager +from langchain.chat_models.tongyi import ChatTongyi from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py b/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py index bb7dab0875b..351e3a9f4de 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py @@ -11,11 +11,11 @@ from typing import Optional from unittest.mock import MagicMock, Mock, patch import pytest +from langchain_core.schema import LLMResult +from langchain_core.schema.messages import AIMessage, HumanMessage, SystemMessage from langchain.chat_models import ChatVertexAI from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples -from langchain.schema import LLMResult -from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage @pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"]) diff --git a/libs/langchain/tests/integration_tests/document_loaders/parsers/test_docai.py b/libs/langchain/tests/integration_tests/document_loaders/parsers/test_docai.py index 3fbabcd4a71..66c2352277d 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/parsers/test_docai.py +++ b/libs/langchain/tests/integration_tests/document_loaders/parsers/test_docai.py @@ -6,9 +6,10 @@ https://cloud.google.com/document-ai/docs/setup """ import os +from langchain_core.schema import Document + from langchain.document_loaders.blob_loaders import Blob from langchain.document_loaders.parsers import DocAIParser -from langchain.schema import Document def test_docai_parser() -> None: diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_arxiv.py b/libs/langchain/tests/integration_tests/document_loaders/test_arxiv.py index fbd5cf45fd8..f3f44032835 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/test_arxiv.py +++ b/libs/langchain/tests/integration_tests/document_loaders/test_arxiv.py @@ -1,9 +1,9 @@ from typing import List import pytest +from langchain_core.schema import Document from langchain.document_loaders.arxiv import ArxivLoader -from langchain.schema import Document def assert_docs(docs: List[Document]) -> None: diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_dataframe.py b/libs/langchain/tests/integration_tests/document_loaders/test_dataframe.py index 2cb5070627c..74e91bcb05f 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/test_dataframe.py +++ b/libs/langchain/tests/integration_tests/document_loaders/test_dataframe.py @@ -1,8 +1,8 @@ import pandas as pd import pytest +from langchain_core.schema import Document from langchain.document_loaders import DataFrameLoader -from langchain.schema import Document @pytest.fixture diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_geodataframe.py b/libs/langchain/tests/integration_tests/document_loaders/test_geodataframe.py index 4b0680b0f11..b4e6d45dca4 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/test_geodataframe.py +++ b/libs/langchain/tests/integration_tests/document_loaders/test_geodataframe.py @@ -3,9 +3,9 @@ from __future__ import annotations from typing import TYPE_CHECKING import pytest +from langchain_core.schema import Document from langchain.document_loaders import GeoDataFrameLoader -from langchain.schema import Document if 
TYPE_CHECKING: from geopandas import GeoDataFrame diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py b/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py index 2858b41e8e3..f1743b451da 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py +++ b/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py @@ -3,9 +3,9 @@ from __future__ import annotations from typing import TYPE_CHECKING import pytest +from langchain_core.schema import Document from langchain.document_loaders import PolarsDataFrameLoader -from langchain.schema import Document if TYPE_CHECKING: import polars as pl diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_pubmed.py b/libs/langchain/tests/integration_tests/document_loaders/test_pubmed.py index 6b58cceda8e..9f2da3e934b 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/test_pubmed.py +++ b/libs/langchain/tests/integration_tests/document_loaders/test_pubmed.py @@ -2,9 +2,9 @@ from typing import List import pytest +from langchain_core.schema import Document from langchain.document_loaders import PubMedLoader -from langchain.schema import Document xmltodict = pytest.importorskip("xmltodict") diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_quip.py b/libs/langchain/tests/integration_tests/document_loaders/test_quip.py index 4e62dea6c5e..8e059714ce5 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/test_quip.py +++ b/libs/langchain/tests/integration_tests/document_loaders/test_quip.py @@ -2,9 +2,9 @@ from typing import Dict from unittest.mock import MagicMock, patch import pytest +from langchain_core.schema import Document from langchain.document_loaders.quip import QuipLoader -from langchain.schema import Document try: from quip_api.quip import QuipClient # noqa: F401 diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_tensorflow_datasets.py b/libs/langchain/tests/integration_tests/document_loaders/test_tensorflow_datasets.py index d36e91a2169..b69bebd0afe 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/test_tensorflow_datasets.py +++ b/libs/langchain/tests/integration_tests/document_loaders/test_tensorflow_datasets.py @@ -4,10 +4,10 @@ from __future__ import annotations from typing import TYPE_CHECKING import pytest +from langchain_core.pydantic_v1 import ValidationError +from langchain_core.schema.document import Document from langchain.document_loaders.tensorflow_datasets import TensorflowDatasetLoader -from langchain.pydantic_v1 import ValidationError -from langchain.schema.document import Document if TYPE_CHECKING: import tensorflow as tf # noqa: E402 diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_xorbits.py b/libs/langchain/tests/integration_tests/document_loaders/test_xorbits.py index a83df608274..dce596cb97d 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/test_xorbits.py +++ b/libs/langchain/tests/integration_tests/document_loaders/test_xorbits.py @@ -1,7 +1,7 @@ import pytest +from langchain_core.schema import Document from langchain.document_loaders import XorbitsLoader -from langchain.schema import Document try: import xorbits # noqa: F401 diff --git a/libs/langchain/tests/integration_tests/llms/test_anthropic.py b/libs/langchain/tests/integration_tests/llms/test_anthropic.py index f68053b2aa8..f95d35000ca 100644 --- 
a/libs/langchain/tests/integration_tests/llms/test_anthropic.py +++ b/libs/langchain/tests/integration_tests/llms/test_anthropic.py @@ -2,10 +2,10 @@ from typing import Generator import pytest +from langchain_core.schema import LLMResult from langchain.callbacks.manager import CallbackManager from langchain.llms.anthropic import Anthropic -from langchain.schema import LLMResult from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/integration_tests/llms/test_azure_openai.py b/libs/langchain/tests/integration_tests/llms/test_azure_openai.py index 5ad6eed2e3a..3c66c1b36be 100644 --- a/libs/langchain/tests/integration_tests/llms/test_azure_openai.py +++ b/libs/langchain/tests/integration_tests/llms/test_azure_openai.py @@ -3,12 +3,12 @@ import os from typing import Any, Generator import pytest +from langchain_core.schema import ( + LLMResult, +) from langchain.callbacks.manager import CallbackManager from langchain.llms import AzureOpenAI -from langchain.schema import ( - LLMResult, -) from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "") diff --git a/libs/langchain/tests/integration_tests/llms/test_chatglm.py b/libs/langchain/tests/integration_tests/llms/test_chatglm.py index a62a76896db..fca5ca34ebc 100644 --- a/libs/langchain/tests/integration_tests/llms/test_chatglm.py +++ b/libs/langchain/tests/integration_tests/llms/test_chatglm.py @@ -1,6 +1,7 @@ """Test ChatGLM API wrapper.""" +from langchain_core.schema import LLMResult + from langchain.llms.chatglm import ChatGLM -from langchain.schema import LLMResult def test_chatglm_call() -> None: diff --git a/libs/langchain/tests/integration_tests/llms/test_fireworks.py b/libs/langchain/tests/integration_tests/llms/test_fireworks.py index 40f5cf1cc5f..9abf445bf21 100644 --- a/libs/langchain/tests/integration_tests/llms/test_fireworks.py +++ b/libs/langchain/tests/integration_tests/llms/test_fireworks.py @@ -3,15 +3,15 @@ import sys from typing import Generator import pytest - -from langchain.chains import LLMChain -from langchain.llms.fireworks import Fireworks -from langchain.prompts import PromptTemplate -from langchain.prompts.chat import ( +from langchain_core.prompts import PromptTemplate +from langchain_core.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, ) -from langchain.schema import LLMResult +from langchain_core.schema import LLMResult + +from langchain.chains import LLMChain +from langchain.llms.fireworks import Fireworks if sys.version_info < (3, 9): pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True) diff --git a/libs/langchain/tests/integration_tests/llms/test_opaqueprompts.py b/libs/langchain/tests/integration_tests/llms/test_opaqueprompts.py index 9efde28caf5..9fd25778093 100644 --- a/libs/langchain/tests/integration_tests/llms/test_opaqueprompts.py +++ b/libs/langchain/tests/integration_tests/llms/test_opaqueprompts.py @@ -1,11 +1,12 @@ +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import RunnableParallel +from langchain_core.schema.output_parser import StrOutputParser + import langchain.utilities.opaqueprompts as op from langchain.chains.llm import LLMChain from langchain.llms import OpenAI from langchain.llms.opaqueprompts import OpaquePrompts from langchain.memory import ConversationBufferWindowMemory -from langchain.prompts import PromptTemplate -from langchain.schema.output_parser import 
StrOutputParser -from langchain.schema.runnable import RunnableParallel prompt_template = """ As an AI assistant, you will answer questions according to given context. diff --git a/libs/langchain/tests/integration_tests/llms/test_openai.py b/libs/langchain/tests/integration_tests/llms/test_openai.py index 07e24f9a86e..a0f2981620b 100644 --- a/libs/langchain/tests/integration_tests/llms/test_openai.py +++ b/libs/langchain/tests/integration_tests/llms/test_openai.py @@ -3,12 +3,12 @@ from pathlib import Path from typing import Generator import pytest +from langchain_core.schema import LLMResult from langchain.callbacks.manager import CallbackManager from langchain.chat_models.openai import ChatOpenAI from langchain.llms.loading import load_llm from langchain.llms.openai import OpenAI -from langchain.schema import LLMResult from tests.unit_tests.callbacks.fake_callback_handler import ( FakeCallbackHandler, ) diff --git a/libs/langchain/tests/integration_tests/llms/test_qianfan_endpoint.py b/libs/langchain/tests/integration_tests/llms/test_qianfan_endpoint.py index 75f47444c88..9c87dd95e5f 100644 --- a/libs/langchain/tests/integration_tests/llms/test_qianfan_endpoint.py +++ b/libs/langchain/tests/integration_tests/llms/test_qianfan_endpoint.py @@ -2,9 +2,9 @@ from typing import Generator import pytest +from langchain_core.schema import LLMResult from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint -from langchain.schema import LLMResult def test_call() -> None: diff --git a/libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py b/libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py index 97761676847..a068e08aca4 100644 --- a/libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py +++ b/libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py @@ -1,7 +1,8 @@ """Test Nebula API wrapper.""" +from langchain_core.prompts.prompt import PromptTemplate + from langchain.chains.llm import LLMChain from langchain.llms.symblai_nebula import Nebula -from langchain.prompts.prompt import PromptTemplate def test_symblai_nebula_call() -> None: diff --git a/libs/langchain/tests/integration_tests/llms/test_tongyi.py b/libs/langchain/tests/integration_tests/llms/test_tongyi.py index de37ff9861a..704d994a524 100644 --- a/libs/langchain/tests/integration_tests/llms/test_tongyi.py +++ b/libs/langchain/tests/integration_tests/llms/test_tongyi.py @@ -1,6 +1,7 @@ """Test Tongyi API wrapper.""" +from langchain_core.schema import LLMResult + from langchain.llms.tongyi import Tongyi -from langchain.schema import LLMResult def test_tongyi_call() -> None: diff --git a/libs/langchain/tests/integration_tests/llms/test_vertexai.py b/libs/langchain/tests/integration_tests/llms/test_vertexai.py index 85e8ca7a261..fdf55c9383b 100644 --- a/libs/langchain/tests/integration_tests/llms/test_vertexai.py +++ b/libs/langchain/tests/integration_tests/llms/test_vertexai.py @@ -10,12 +10,12 @@ Your end-user credentials would be used to make the calls (make sure you've run import os import pytest +from langchain_core.schema import LLMResult from pytest_mock import MockerFixture from langchain.chains.summarize import load_summarize_chain from langchain.docstore.document import Document from langchain.llms import VertexAI, VertexAIModelGarden -from langchain.schema import LLMResult def test_vertex_initialization() -> None: diff --git a/libs/langchain/tests/integration_tests/memory/chat_message_histories/test_zep.py 
b/libs/langchain/tests/integration_tests/memory/chat_message_histories/test_zep.py index e2ffd7ac78f..157dd7bc3d7 100644 --- a/libs/langchain/tests/integration_tests/memory/chat_message_histories/test_zep.py +++ b/libs/langchain/tests/integration_tests/memory/chat_message_histories/test_zep.py @@ -1,10 +1,10 @@ from typing import TYPE_CHECKING import pytest +from langchain_core.schema.messages import AIMessage, HumanMessage, SystemMessage from pytest_mock import MockerFixture from langchain.memory.chat_message_histories import ZepChatMessageHistory -from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage if TYPE_CHECKING: from zep_python import ZepClient diff --git a/libs/langchain/tests/integration_tests/memory/test_cassandra.py b/libs/langchain/tests/integration_tests/memory/test_cassandra.py index 3e6572f58b7..d10e3ee5a36 100644 --- a/libs/langchain/tests/integration_tests/memory/test_cassandra.py +++ b/libs/langchain/tests/integration_tests/memory/test_cassandra.py @@ -3,12 +3,12 @@ import time from typing import Optional from cassandra.cluster import Cluster +from langchain_core.schema.messages import AIMessage, HumanMessage from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories.cassandra import ( CassandraChatMessageHistory, ) -from langchain.schema.messages import AIMessage, HumanMessage def _chat_message_history( diff --git a/libs/langchain/tests/integration_tests/memory/test_cosmos_db.py b/libs/langchain/tests/integration_tests/memory/test_cosmos_db.py index 0a32883351b..ea927c97f36 100644 --- a/libs/langchain/tests/integration_tests/memory/test_cosmos_db.py +++ b/libs/langchain/tests/integration_tests/memory/test_cosmos_db.py @@ -1,9 +1,10 @@ import json import os +from langchain_core.schema.messages import _message_to_dict + from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import CosmosDBChatMessageHistory -from langchain.schema.messages import _message_to_dict # Replace these with your Azure Cosmos DB endpoint and key endpoint = os.environ.get("COSMOS_DB_ENDPOINT", "") diff --git a/libs/langchain/tests/integration_tests/memory/test_elasticsearch.py b/libs/langchain/tests/integration_tests/memory/test_elasticsearch.py index 6ea68d679dd..46694dfee17 100644 --- a/libs/langchain/tests/integration_tests/memory/test_elasticsearch.py +++ b/libs/langchain/tests/integration_tests/memory/test_elasticsearch.py @@ -4,10 +4,10 @@ import uuid from typing import Generator, Union import pytest +from langchain_core.schema.messages import _message_to_dict from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import ElasticsearchChatMessageHistory -from langchain.schema.messages import _message_to_dict """ cd tests/integration_tests/memory/docker-compose diff --git a/libs/langchain/tests/integration_tests/memory/test_firestore.py b/libs/langchain/tests/integration_tests/memory/test_firestore.py index 4dcff90f505..b75802be848 100644 --- a/libs/langchain/tests/integration_tests/memory/test_firestore.py +++ b/libs/langchain/tests/integration_tests/memory/test_firestore.py @@ -1,8 +1,9 @@ import json +from langchain_core.schema.messages import _message_to_dict + from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import FirestoreChatMessageHistory -from langchain.schema.messages import _message_to_dict def test_memory_with_message_store() -> None: diff --git 
a/libs/langchain/tests/integration_tests/memory/test_momento.py b/libs/langchain/tests/integration_tests/memory/test_momento.py index 0e20c4d183b..2cb928d8f36 100644 --- a/libs/langchain/tests/integration_tests/memory/test_momento.py +++ b/libs/langchain/tests/integration_tests/memory/test_momento.py @@ -10,10 +10,10 @@ from datetime import timedelta from typing import Iterator import pytest +from langchain_core.schema.messages import _message_to_dict from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import MomentoChatMessageHistory -from langchain.schema.messages import _message_to_dict def random_string() -> str: diff --git a/libs/langchain/tests/integration_tests/memory/test_mongodb.py b/libs/langchain/tests/integration_tests/memory/test_mongodb.py index cb7cdbe7352..6fb7c1a8b2b 100644 --- a/libs/langchain/tests/integration_tests/memory/test_mongodb.py +++ b/libs/langchain/tests/integration_tests/memory/test_mongodb.py @@ -1,9 +1,10 @@ import json import os +from langchain_core.schema.messages import _message_to_dict + from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import MongoDBChatMessageHistory -from langchain.schema.messages import _message_to_dict # Replace these with your mongodb connection string connection_string = os.environ.get("MONGODB_CONNECTION_STRING", "") diff --git a/libs/langchain/tests/integration_tests/memory/test_neo4j.py b/libs/langchain/tests/integration_tests/memory/test_neo4j.py index d14e2c81f25..9ee5d3072c2 100644 --- a/libs/langchain/tests/integration_tests/memory/test_neo4j.py +++ b/libs/langchain/tests/integration_tests/memory/test_neo4j.py @@ -1,8 +1,9 @@ import json +from langchain_core.schema.messages import _message_to_dict + from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import Neo4jChatMessageHistory -from langchain.schema.messages import _message_to_dict def test_memory_with_message_store() -> None: diff --git a/libs/langchain/tests/integration_tests/memory/test_redis.py b/libs/langchain/tests/integration_tests/memory/test_redis.py index 547f6ab3c21..308ad0e7d56 100644 --- a/libs/langchain/tests/integration_tests/memory/test_redis.py +++ b/libs/langchain/tests/integration_tests/memory/test_redis.py @@ -1,8 +1,9 @@ import json +from langchain_core.schema.messages import _message_to_dict + from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import RedisChatMessageHistory -from langchain.schema.messages import _message_to_dict def test_memory_with_message_store() -> None: diff --git a/libs/langchain/tests/integration_tests/memory/test_rockset.py b/libs/langchain/tests/integration_tests/memory/test_rockset.py index 2d6ef5dbdba..2817aefe3e0 100644 --- a/libs/langchain/tests/integration_tests/memory/test_rockset.py +++ b/libs/langchain/tests/integration_tests/memory/test_rockset.py @@ -8,9 +8,10 @@ and ROCKSET_REGION environment variables set. 
import json import os +from langchain_core.schema.messages import _message_to_dict + from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import RocksetChatMessageHistory -from langchain.schema.messages import _message_to_dict collection_name = "langchain_demo" session_id = "MySession" diff --git a/libs/langchain/tests/integration_tests/memory/test_singlestoredb.py b/libs/langchain/tests/integration_tests/memory/test_singlestoredb.py index 92a611a5292..b41c7837703 100644 --- a/libs/langchain/tests/integration_tests/memory/test_singlestoredb.py +++ b/libs/langchain/tests/integration_tests/memory/test_singlestoredb.py @@ -1,7 +1,8 @@ import json +from langchain_core.schema.messages import _message_to_dict + from langchain.memory import ConversationBufferMemory, SingleStoreDBChatMessageHistory -from langchain.schema.messages import _message_to_dict # Replace these with your mongodb connection string TEST_SINGLESTOREDB_URL = "root:pass@localhost:3306/db" diff --git a/libs/langchain/tests/integration_tests/memory/test_upstash_redis.py b/libs/langchain/tests/integration_tests/memory/test_upstash_redis.py index ff9b9a445ea..dfc2746b581 100644 --- a/libs/langchain/tests/integration_tests/memory/test_upstash_redis.py +++ b/libs/langchain/tests/integration_tests/memory/test_upstash_redis.py @@ -1,12 +1,12 @@ import json import pytest +from langchain_core.schema.messages import _message_to_dict from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories.upstash_redis import ( UpstashRedisChatMessageHistory, ) -from langchain.schema.messages import _message_to_dict URL = "" TOKEN = "" diff --git a/libs/langchain/tests/integration_tests/memory/test_xata.py b/libs/langchain/tests/integration_tests/memory/test_xata.py index 88bd158a257..7b74142fd33 100644 --- a/libs/langchain/tests/integration_tests/memory/test_xata.py +++ b/libs/langchain/tests/integration_tests/memory/test_xata.py @@ -6,9 +6,10 @@ Before running this test, please create a Xata database. 
import json import os +from langchain_core.schema.messages import _message_to_dict + from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import XataChatMessageHistory -from langchain.schema.messages import _message_to_dict class TestXata: diff --git a/libs/langchain/tests/integration_tests/prompts/test_ngram_overlap_example_selector.py b/libs/langchain/tests/integration_tests/prompts/test_ngram_overlap_example_selector.py index 5c7bd4b140a..61c976e4faa 100644 --- a/libs/langchain/tests/integration_tests/prompts/test_ngram_overlap_example_selector.py +++ b/libs/langchain/tests/integration_tests/prompts/test_ngram_overlap_example_selector.py @@ -1,12 +1,12 @@ """Test functionality related to ngram overlap based selector.""" import pytest +from langchain_core.prompts.prompt import PromptTemplate from langchain.prompts.example_selector.ngram_overlap import ( NGramOverlapExampleSelector, ngram_overlap_score, ) -from langchain.prompts.prompt import PromptTemplate EXAMPLES = [ {"input": "See Spot run.", "output": "foo1"}, diff --git a/libs/langchain/tests/integration_tests/retrievers/docarray/fixtures.py b/libs/langchain/tests/integration_tests/retrievers/docarray/fixtures.py index 951a38a6657..2639f980ed9 100644 --- a/libs/langchain/tests/integration_tests/retrievers/docarray/fixtures.py +++ b/libs/langchain/tests/integration_tests/retrievers/docarray/fixtures.py @@ -5,8 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple import numpy as np import pytest - -from langchain.pydantic_v1 import Field +from langchain_core.pydantic_v1 import Field if TYPE_CHECKING: from docarray.index import ( diff --git a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_base.py b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_base.py index 389d4d04e7e..709378b8e9c 100644 --- a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_base.py +++ b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_base.py @@ -1,11 +1,12 @@ """Integration test for compression pipelines.""" +from langchain_core.schema import Document + from langchain.document_transformers import EmbeddingsRedundantFilter from langchain.embeddings import OpenAIEmbeddings from langchain.retrievers.document_compressors import ( DocumentCompressorPipeline, EmbeddingsFilter, ) -from langchain.schema import Document from langchain.text_splitter import CharacterTextSplitter diff --git a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py index 7434f665dc8..4c03bfa5bcb 100644 --- a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py +++ b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py @@ -1,7 +1,8 @@ """Integration test for LLMChainExtractor.""" +from langchain_core.schema import Document + from langchain.chat_models import ChatOpenAI from langchain.retrievers.document_compressors import LLMChainExtractor -from langchain.schema import Document def test_llm_construction_with_kwargs() -> None: diff --git a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py index 1068a1e65a2..4891a56e5ea 100644 --- 
a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py +++ b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py @@ -1,7 +1,8 @@ """Integration test for llm-based relevant doc filtering.""" +from langchain_core.schema import Document + from langchain.chat_models import ChatOpenAI from langchain.retrievers.document_compressors import LLMChainFilter -from langchain.schema import Document def test_llm_chain_filter() -> None: diff --git a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py index ad2f71d7bf2..ae45cf5cf73 100644 --- a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py +++ b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py @@ -1,12 +1,12 @@ """Integration test for embedding-based relevant doc filtering.""" import numpy as np +from langchain_core.schema import Document from langchain.document_transformers.embeddings_redundant_filter import ( _DocumentWithState, ) from langchain.embeddings import OpenAIEmbeddings from langchain.retrievers.document_compressors import EmbeddingsFilter -from langchain.schema import Document def test_embeddings_filter() -> None: diff --git a/libs/langchain/tests/integration_tests/retrievers/test_arxiv.py b/libs/langchain/tests/integration_tests/retrievers/test_arxiv.py index f112c3e94c3..b2f557a7d8d 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_arxiv.py +++ b/libs/langchain/tests/integration_tests/retrievers/test_arxiv.py @@ -2,9 +2,9 @@ from typing import List import pytest +from langchain_core.schema import Document from langchain.retrievers import ArxivRetriever -from langchain.schema import Document @pytest.fixture diff --git a/libs/langchain/tests/integration_tests/retrievers/test_azure_cognitive_search.py b/libs/langchain/tests/integration_tests/retrievers/test_azure_cognitive_search.py index effa1b79321..0c7af90ee49 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_azure_cognitive_search.py +++ b/libs/langchain/tests/integration_tests/retrievers/test_azure_cognitive_search.py @@ -1,8 +1,8 @@ """Test Azure Cognitive Search wrapper.""" import pytest +from langchain_core.schema import Document from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever -from langchain.schema import Document def test_azure_cognitive_search_get_relevant_documents() -> None: diff --git a/libs/langchain/tests/integration_tests/retrievers/test_google_docai_warehoure_retriever.py b/libs/langchain/tests/integration_tests/retrievers/test_google_docai_warehoure_retriever.py index 490c22e3668..08de3761242 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_google_docai_warehoure_retriever.py +++ b/libs/langchain/tests/integration_tests/retrievers/test_google_docai_warehoure_retriever.py @@ -1,8 +1,9 @@ """Test Google Cloud Document AI Warehouse retriever.""" import os +from langchain_core.schema import Document + from langchain.retrievers import GoogleDocumentAIWarehouseRetriever -from langchain.schema import Document def test_google_documentai_warehoure_retriever() -> None: diff --git a/libs/langchain/tests/integration_tests/retrievers/test_google_vertex_ai_search.py b/libs/langchain/tests/integration_tests/retrievers/test_google_vertex_ai_search.py index 940c4c27d4b..3bc2937b5cf 
100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_google_vertex_ai_search.py +++ b/libs/langchain/tests/integration_tests/retrievers/test_google_vertex_ai_search.py @@ -14,13 +14,13 @@ export DATA_STORE_ID=... - the ID of the search engine to use for the test import os import pytest +from langchain_core.schema import Document from langchain.retrievers.google_vertex_ai_search import ( GoogleCloudEnterpriseSearchRetriever, GoogleVertexAIMultiTurnSearchRetriever, GoogleVertexAISearchRetriever, ) -from langchain.schema import Document @pytest.mark.requires("google_api_core") diff --git a/libs/langchain/tests/integration_tests/retrievers/test_kay.py b/libs/langchain/tests/integration_tests/retrievers/test_kay.py index 84c511fd6f8..73754b1c7b9 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_kay.py +++ b/libs/langchain/tests/integration_tests/retrievers/test_kay.py @@ -1,8 +1,8 @@ """Integration test for Kay.ai API Wrapper.""" import pytest +from langchain_core.schema import Document from langchain.retrievers import KayAiRetriever -from langchain.schema import Document @pytest.mark.requires("kay") diff --git a/libs/langchain/tests/integration_tests/retrievers/test_pubmed.py b/libs/langchain/tests/integration_tests/retrievers/test_pubmed.py index 7819c03e0ed..c0a5a37b9c6 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_pubmed.py +++ b/libs/langchain/tests/integration_tests/retrievers/test_pubmed.py @@ -2,9 +2,9 @@ from typing import List import pytest +from langchain_core.schema import Document from langchain.retrievers import PubMedRetriever -from langchain.schema import Document @pytest.fixture diff --git a/libs/langchain/tests/integration_tests/retrievers/test_wikipedia.py b/libs/langchain/tests/integration_tests/retrievers/test_wikipedia.py index e2d831c7517..2188e438707 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_wikipedia.py +++ b/libs/langchain/tests/integration_tests/retrievers/test_wikipedia.py @@ -2,9 +2,9 @@ from typing import List import pytest +from langchain_core.schema import Document from langchain.retrievers import WikipediaRetriever -from langchain.schema import Document @pytest.fixture diff --git a/libs/langchain/tests/integration_tests/retrievers/test_zep.py b/libs/langchain/tests/integration_tests/retrievers/test_zep.py index ff6f7e4eacf..f989c66951a 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_zep.py +++ b/libs/langchain/tests/integration_tests/retrievers/test_zep.py @@ -4,10 +4,10 @@ import copy from typing import TYPE_CHECKING, List import pytest +from langchain_core.schema import Document from pytest_mock import MockerFixture from langchain.retrievers import ZepRetriever -from langchain.schema import Document if TYPE_CHECKING: from zep_python import MemorySearchResult, ZepClient diff --git a/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py b/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py index 6010b674c49..231828d3aab 100644 --- a/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py +++ b/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py @@ -2,6 +2,8 @@ from typing import Iterator, List from uuid import uuid4 import pytest +from langchain_core.prompts.chat import ChatPromptTemplate +from langchain_core.schema.messages import BaseMessage, HumanMessage from langsmith import Client as Client from langsmith.schemas import DataType @@ -9,8 +11,6 @@ from 
langchain.chains.llm import LLMChain from langchain.chat_models import ChatOpenAI from langchain.evaluation import EvaluatorType from langchain.llms.openai import OpenAI -from langchain.prompts.chat import ChatPromptTemplate -from langchain.schema.messages import BaseMessage, HumanMessage from langchain.smith import RunEvalConfig, run_on_dataset from langchain.smith.evaluation import InputFormatError from langchain.smith.evaluation.runner_utils import arun_on_dataset diff --git a/libs/langchain/tests/integration_tests/test_document_transformers.py b/libs/langchain/tests/integration_tests/test_document_transformers.py index 1a53e67d613..63a2f1e9a46 100644 --- a/libs/langchain/tests/integration_tests/test_document_transformers.py +++ b/libs/langchain/tests/integration_tests/test_document_transformers.py @@ -1,11 +1,12 @@ """Integration test for embedding-based redundant doc filtering.""" +from langchain_core.schema import Document + from langchain.document_transformers.embeddings_redundant_filter import ( EmbeddingsClusteringFilter, EmbeddingsRedundantFilter, _DocumentWithState, ) from langchain.embeddings import OpenAIEmbeddings -from langchain.schema import Document def test_embeddings_redundant_filter() -> None: diff --git a/libs/langchain/tests/integration_tests/test_nuclia_transformer.py b/libs/langchain/tests/integration_tests/test_nuclia_transformer.py index 55acfb8f607..6e803ac6bc1 100644 --- a/libs/langchain/tests/integration_tests/test_nuclia_transformer.py +++ b/libs/langchain/tests/integration_tests/test_nuclia_transformer.py @@ -4,9 +4,9 @@ from typing import Any from unittest import mock import pytest +from langchain_core.schema.document import Document from langchain.document_transformers.nuclia_text_transform import NucliaTextTransformer -from langchain.schema.document import Document from langchain.tools.nuclia.tool import NucliaUnderstandingAPI diff --git a/libs/langchain/tests/integration_tests/test_schema.py b/libs/langchain/tests/integration_tests/test_schema.py index 9ff2609476f..7d6d7865033 100644 --- a/libs/langchain/tests/integration_tests/test_schema.py +++ b/libs/langchain/tests/integration_tests/test_schema.py @@ -1,6 +1,6 @@ """Test formatting functionality.""" -from langchain.schema.language_model import _get_token_ids_default_method +from langchain_core.schema.language_model import _get_token_ids_default_method class TestTokenCountingWithGPT2Tokenizer: diff --git a/libs/langchain/tests/integration_tests/utilities/test_arxiv.py b/libs/langchain/tests/integration_tests/utilities/test_arxiv.py index fb1029d6984..59d1bed435c 100644 --- a/libs/langchain/tests/integration_tests/utilities/test_arxiv.py +++ b/libs/langchain/tests/integration_tests/utilities/test_arxiv.py @@ -2,9 +2,9 @@ from typing import Any, List import pytest +from langchain_core.schema import Document from langchain.agents.load_tools import load_tools -from langchain.schema import Document from langchain.tools import ArxivQueryRun from langchain.tools.base import BaseTool from langchain.utilities import ArxivAPIWrapper diff --git a/libs/langchain/tests/integration_tests/utilities/test_pubmed.py b/libs/langchain/tests/integration_tests/utilities/test_pubmed.py index d015cb06fd9..75a74398a72 100644 --- a/libs/langchain/tests/integration_tests/utilities/test_pubmed.py +++ b/libs/langchain/tests/integration_tests/utilities/test_pubmed.py @@ -2,9 +2,9 @@ from typing import Any, List import pytest +from langchain_core.schema import Document from langchain.agents.load_tools import load_tools -from 
langchain.schema import Document from langchain.tools import PubmedQueryRun from langchain.tools.base import BaseTool from langchain.utilities import PubMedAPIWrapper diff --git a/libs/langchain/tests/integration_tests/utilities/test_tensorflow_datasets.py b/libs/langchain/tests/integration_tests/utilities/test_tensorflow_datasets.py index 811fec0611e..a2883c46c17 100644 --- a/libs/langchain/tests/integration_tests/utilities/test_tensorflow_datasets.py +++ b/libs/langchain/tests/integration_tests/utilities/test_tensorflow_datasets.py @@ -4,9 +4,9 @@ from __future__ import annotations from typing import TYPE_CHECKING import pytest +from langchain_core.pydantic_v1 import ValidationError +from langchain_core.schema.document import Document -from langchain.pydantic_v1 import ValidationError -from langchain.schema.document import Document from langchain.utilities.tensorflow_datasets import TensorflowDatasets if TYPE_CHECKING: diff --git a/libs/langchain/tests/integration_tests/utilities/test_wikipedia_api.py b/libs/langchain/tests/integration_tests/utilities/test_wikipedia_api.py index ff5b08425e5..e4461403b3e 100644 --- a/libs/langchain/tests/integration_tests/utilities/test_wikipedia_api.py +++ b/libs/langchain/tests/integration_tests/utilities/test_wikipedia_api.py @@ -2,8 +2,8 @@ from typing import List import pytest +from langchain_core.schema import Document -from langchain.schema import Document from langchain.utilities import WikipediaAPIWrapper diff --git a/libs/langchain/tests/integration_tests/vectorstores/conftest.py b/libs/langchain/tests/integration_tests/vectorstores/conftest.py index 507e0e1eead..2af17fa945d 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/conftest.py +++ b/libs/langchain/tests/integration_tests/vectorstores/conftest.py @@ -2,11 +2,11 @@ import os from typing import Generator, List, Union import pytest +from langchain_core.schema import Document from vcr.request import Request from langchain.document_loaders import TextLoader from langchain.embeddings import OpenAIEmbeddings -from langchain.schema import Document from langchain.text_splitter import CharacterTextSplitter # Those environment variables turn on Deep Lake pytest mode. 
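The hunks above and below all apply one mechanical change: imports of the shared abstractions (Document, the message classes, the pydantic_v1 shim, prompt templates, and private helpers such as _message_to_dict) now come from the new langchain_core package, while code that still lives in langchain keeps its existing import path. A minimal sketch of the resulting import style, assuming the langchain_core layout introduced by this patch (the illustrative Document and HumanMessage values are hypothetical):

    # Old style (pre-split): everything came from the monolithic langchain package.
    #   from langchain.schema import Document
    #   from langchain.schema.messages import HumanMessage
    #   from langchain.pydantic_v1 import BaseModel

    # New style (post-split): the core abstractions come from langchain_core.
    from langchain_core.pydantic_v1 import BaseModel
    from langchain_core.schema import Document
    from langchain_core.schema.messages import HumanMessage

    # The classes behave the same; only the import path changes.
    doc = Document(page_content="See Spot run.", metadata={"source": "example"})
    msg = HumanMessage(content="hello")

Because langchain_core is a separate distribution, isort treats it as third-party: it sorts into the same block as pytest and numpy, and the remaining first-party langchain imports are regrouped after a blank line, which accounts for most of the reordering churn in these test files.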
diff --git a/libs/langchain/tests/integration_tests/vectorstores/docarray/test_hnsw.py b/libs/langchain/tests/integration_tests/vectorstores/docarray/test_hnsw.py index 0143660f126..39ef184a19c 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/docarray/test_hnsw.py +++ b/libs/langchain/tests/integration_tests/vectorstores/docarray/test_hnsw.py @@ -3,8 +3,8 @@ from typing import List import numpy as np import pytest +from langchain_core.schema import Document -from langchain.schema import Document from langchain.vectorstores.docarray import DocArrayHnswSearch from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings diff --git a/libs/langchain/tests/integration_tests/vectorstores/docarray/test_in_memory.py b/libs/langchain/tests/integration_tests/vectorstores/docarray/test_in_memory.py index ca556b11cc5..437f9f76dd8 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/docarray/test_in_memory.py +++ b/libs/langchain/tests/integration_tests/vectorstores/docarray/test_in_memory.py @@ -3,8 +3,8 @@ from typing import List import numpy as np import pytest +from langchain_core.schema import Document -from langchain.schema import Document from langchain.vectorstores.docarray import DocArrayInMemorySearch from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings diff --git a/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py index 87ea1edc6a0..d5d26072448 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py +++ b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py @@ -2,7 +2,7 @@ import math from typing import List -from langchain.schema.embeddings import Embeddings +from langchain_core.schema.embeddings import Embeddings fake_texts = ["foo", "bar", "baz"] diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py index a2c4b5eb77b..3cfe7825edf 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py @@ -2,8 +2,8 @@ import uuid from typing import Optional import pytest +from langchain_core.schema import Document -from langchain.schema import Document from langchain.vectorstores import Qdrant from langchain.vectorstores.qdrant import QdrantException from tests.integration_tests.vectorstores.fake_embeddings import ( diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_max_marginal_relevance.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_max_marginal_relevance.py index 2784b4c424b..d46e71f54e6 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_max_marginal_relevance.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_max_marginal_relevance.py @@ -1,8 +1,8 @@ from typing import Optional import pytest +from langchain_core.schema import Document -from langchain.schema import Document from langchain.vectorstores import Qdrant from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_similarity_search.py 
b/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_similarity_search.py index 55e1fc3aa30..13dd23d9b7a 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_similarity_search.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/async_api/test_similarity_search.py @@ -2,8 +2,8 @@ from typing import Optional import numpy as np import pytest +from langchain_core.schema import Document -from langchain.schema import Document from langchain.vectorstores import Qdrant from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_add_texts.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_add_texts.py index 052ef3c0747..315e36d3358 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_add_texts.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_add_texts.py @@ -2,8 +2,8 @@ import uuid from typing import Optional import pytest +from langchain_core.schema import Document -from langchain.schema import Document from langchain.vectorstores import Qdrant from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_embedding_interface.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_embedding_interface.py index 76fb6686555..f788e474a13 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_embedding_interface.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_embedding_interface.py @@ -2,8 +2,8 @@ import uuid from typing import Callable, Optional import pytest +from langchain_core.schema.embeddings import Embeddings -from langchain.schema.embeddings import Embeddings from langchain.vectorstores import Qdrant from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_from_texts.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_from_texts.py index 142b5a10e82..a9088170d95 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_from_texts.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_from_texts.py @@ -3,8 +3,8 @@ import uuid from typing import Optional import pytest +from langchain_core.schema import Document -from langchain.schema import Document from langchain.vectorstores import Qdrant from langchain.vectorstores.qdrant import QdrantException from tests.integration_tests.vectorstores.fake_embeddings import ( diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py index 654bd9a4339..56a84c38db5 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py @@ -1,8 +1,8 @@ from typing import Optional import pytest +from langchain_core.schema import Document -from langchain.schema import Document from langchain.vectorstores import Qdrant from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_similarity_search.py 
b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_similarity_search.py index b1aae06ab57..beab6b3ad8b 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_similarity_search.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_similarity_search.py @@ -2,8 +2,8 @@ from typing import Optional import numpy as np import pytest +from langchain_core.schema import Document -from langchain.schema import Document from langchain.vectorstores import Qdrant from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_alibabacloud_opensearch.py b/libs/langchain/tests/integration_tests/vectorstores/test_alibabacloud_opensearch.py index 5e324f95707..87d43d8f2f6 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_alibabacloud_opensearch.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_alibabacloud_opensearch.py @@ -1,7 +1,8 @@ import time from typing import List -from langchain.schema import Document +from langchain_core.schema import Document + from langchain.vectorstores.alibabacloud_opensearch import ( AlibabaCloudOpenSearch, AlibabaCloudOpenSearchSettings, diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_astradb.py b/libs/langchain/tests/integration_tests/vectorstores/test_astradb.py index 25aee4f6ce0..ab4ec88076b 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_astradb.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_astradb.py @@ -17,9 +17,9 @@ import os from typing import Iterable, List import pytest +from langchain_core.schema import Document from langchain.embeddings.base import Embeddings -from langchain.schema import Document from langchain.vectorstores import AstraDB # Ad-hoc embedding classes: diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_dashvector.py b/libs/langchain/tests/integration_tests/vectorstores/test_dashvector.py index e01a12073e3..23f0c3d3cc1 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_dashvector.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_dashvector.py @@ -1,6 +1,7 @@ from time import sleep -from langchain.schema import Document +from langchain_core.schema import Document + from langchain.vectorstores import DashVector from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_mongodb_atlas.py b/libs/langchain/tests/integration_tests/vectorstores/test_mongodb_atlas.py index 3bd2193ff81..be9532eab29 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_mongodb_atlas.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_mongodb_atlas.py @@ -6,9 +6,9 @@ from time import sleep from typing import Any import pytest +from langchain_core.schema.embeddings import Embeddings from langchain.docstore.document import Document -from langchain.schema.embeddings import Embeddings from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch INDEX_NAME = "langchain-test-index" diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_zep.py b/libs/langchain/tests/integration_tests/vectorstores/test_zep.py index 4ed8245b841..7db7f72f1fd 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_zep.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_zep.py @@ -5,9 +5,9 @@ from typing import TYPE_CHECKING, Any, 
Dict, List, Optional, Tuple from uuid import uuid4 import pytest +from langchain_core.schema import Document from pytest_mock import MockerFixture -from langchain.schema import Document from langchain.vectorstores import ZepVectorStore from langchain.vectorstores.zep import CollectionConfig diff --git a/libs/langchain/tests/mock_servers/robot/server.py b/libs/langchain/tests/mock_servers/robot/server.py index 54e32f513a6..dd332c5c07c 100644 --- a/libs/langchain/tests/mock_servers/robot/server.py +++ b/libs/langchain/tests/mock_servers/robot/server.py @@ -7,8 +7,7 @@ import uvicorn from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi - -from langchain.pydantic_v1 import BaseModel, Field +from langchain_core.pydantic_v1 import BaseModel, Field PORT = 7289 diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py index 411b57695c7..8bf376e55dd 100644 --- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py @@ -1,5 +1,6 @@ +from langchain_core.schema.agent import AgentAction + from langchain.agents.format_scratchpad.log import format_log_to_str -from langchain.schema.agent import AgentAction def test_single_agent_action_observation() -> None: diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py index ed7664c8b04..2648481d6ca 100644 --- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py @@ -1,6 +1,7 @@ +from langchain_core.schema.agent import AgentAction +from langchain_core.schema.messages import AIMessage, HumanMessage + from langchain.agents.format_scratchpad.log_to_messages import format_log_to_messages -from langchain.schema.agent import AgentAction -from langchain.schema.messages import AIMessage, HumanMessage def test_single_intermediate_step_default_response() -> None: diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py index 72227f74800..bc4350444f0 100644 --- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py @@ -1,8 +1,9 @@ +from langchain_core.schema.agent import AgentActionMessageLog +from langchain_core.schema.messages import AIMessage, FunctionMessage + from langchain.agents.format_scratchpad.openai_functions import ( format_to_openai_function_messages, ) -from langchain.schema.agent import AgentActionMessageLog -from langchain.schema.messages import AIMessage, FunctionMessage def test_calls_convert_agent_action_to_messages() -> None: diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py index 2509091ffd7..9c08040e905 100644 --- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py @@ -1,5 +1,6 @@ +from langchain_core.schema.agent import AgentAction + from langchain.agents.format_scratchpad.xml import format_xml -from 
langchain.schema.agent import AgentAction def test_single_agent_action_observation() -> None: diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_json.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_json.py index 49d57d47c94..e25cca3989b 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_json.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_json.py @@ -1,5 +1,6 @@ +from langchain_core.schema.agent import AgentAction, AgentFinish + from langchain.agents.output_parsers.json import JSONAgentOutputParser -from langchain.schema.agent import AgentAction, AgentFinish def test_tool_usage() -> None: diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_openai_functions.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_openai_functions.py index 613a486a4b6..53713e81730 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_openai_functions.py @@ -1,11 +1,11 @@ import pytest +from langchain_core.schema import AgentFinish, OutputParserException +from langchain_core.schema.agent import AgentActionMessageLog +from langchain_core.schema.messages import AIMessage, SystemMessage from langchain.agents.output_parsers.openai_functions import ( OpenAIFunctionsAgentOutputParser, ) -from langchain.schema import AgentFinish, OutputParserException -from langchain.schema.agent import AgentActionMessageLog -from langchain.schema.messages import AIMessage, SystemMessage def test_not_an_ai() -> None: diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_json_single_input.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_json_single_input.py index 1a657a8ddaf..86ee4464925 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_json_single_input.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_json_single_input.py @@ -1,7 +1,8 @@ +from langchain_core.schema.agent import AgentAction, AgentFinish + from langchain.agents.output_parsers.react_json_single_input import ( ReActJsonSingleInputOutputParser, ) -from langchain.schema.agent import AgentAction, AgentFinish def test_action() -> None: diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py index 3996fc3e09b..f3cb2e56721 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py @@ -1,10 +1,10 @@ import pytest +from langchain_core.schema.agent import AgentAction, AgentFinish +from langchain_core.schema.output_parser import OutputParserException from langchain.agents.output_parsers.react_single_input import ( ReActSingleInputOutputParser, ) -from langchain.schema.agent import AgentAction, AgentFinish -from langchain.schema.output_parser import OutputParserException def test_action() -> None: diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_self_ask.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_self_ask.py index c3695060611..5902c5c6d9b 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_self_ask.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_self_ask.py @@ -1,5 +1,6 @@ +from langchain_core.schema.agent import AgentAction, AgentFinish + from 
langchain.agents.output_parsers.self_ask import SelfAskOutputParser -from langchain.schema.agent import AgentAction, AgentFinish def test_follow_up() -> None: diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py index 0119c931ba4..54e39339aba 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py @@ -1,5 +1,6 @@ +from langchain_core.schema.agent import AgentAction, AgentFinish + from langchain.agents.output_parsers.xml import XMLAgentOutputParser -from langchain.schema.agent import AgentAction, AgentFinish def test_tool_usage() -> None: diff --git a/libs/langchain/tests/unit_tests/agents/test_chat.py b/libs/langchain/tests/unit_tests/agents/test_chat.py index 68ad4e0a3a3..069836a5ec1 100644 --- a/libs/langchain/tests/unit_tests/agents/test_chat.py +++ b/libs/langchain/tests/unit_tests/agents/test_chat.py @@ -1,8 +1,9 @@ """Unittests for langchain.agents.chat package.""" from typing import Tuple +from langchain_core.schema import AgentAction + from langchain.agents.chat.output_parser import ChatOutputParser -from langchain.schema import AgentAction output_parser = ChatOutputParser() diff --git a/libs/langchain/tests/unit_tests/agents/test_mrkl.py b/libs/langchain/tests/unit_tests/agents/test_mrkl.py index 0fda94fa647..b6cb39f31fe 100644 --- a/libs/langchain/tests/unit_tests/agents/test_mrkl.py +++ b/libs/langchain/tests/unit_tests/agents/test_mrkl.py @@ -3,13 +3,13 @@ from typing import Tuple import pytest +from langchain_core.prompts import PromptTemplate +from langchain_core.schema import AgentAction, OutputParserException from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.output_parser import MRKLOutputParser from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.tools import Tool -from langchain.prompts import PromptTemplate -from langchain.schema import AgentAction, OutputParserException from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/unit_tests/agents/test_mrkl_output_parser.py b/libs/langchain/tests/unit_tests/agents/test_mrkl_output_parser.py index 8774451df18..78803a0e5c9 100644 --- a/libs/langchain/tests/unit_tests/agents/test_mrkl_output_parser.py +++ b/libs/langchain/tests/unit_tests/agents/test_mrkl_output_parser.py @@ -1,11 +1,11 @@ import pytest +from langchain_core.schema import AgentAction, AgentFinish, OutputParserException from langchain.agents.mrkl.output_parser import ( MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE, MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE, MRKLOutputParser, ) -from langchain.schema import AgentAction, AgentFinish, OutputParserException mrkl_output_parser = MRKLOutputParser() diff --git a/libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py b/libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py index a76f790a626..e846d803e57 100644 --- a/libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py +++ b/libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py @@ -1,13 +1,13 @@ import json import pytest +from langchain_core.schema import AgentFinish, OutputParserException +from langchain_core.schema.messages import AIMessage, SystemMessage from langchain.agents.openai_functions_multi_agent.base import ( _FunctionsAgentAction, _parse_ai_message, ) -from langchain.schema import AgentFinish, 
OutputParserException -from langchain.schema.messages import AIMessage, SystemMessage # Test: _parse_ai_message() function. diff --git a/libs/langchain/tests/unit_tests/agents/test_react.py b/libs/langchain/tests/unit_tests/agents/test_react.py index d61e50db68e..b5b8f0be9de 100644 --- a/libs/langchain/tests/unit_tests/agents/test_react.py +++ b/libs/langchain/tests/unit_tests/agents/test_react.py @@ -2,13 +2,14 @@ from typing import Union +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.schema import AgentAction + from langchain.agents.react.base import ReActChain, ReActDocstoreAgent from langchain.agents.tools import Tool from langchain.docstore.base import Docstore from langchain.docstore.document import Document from langchain.llms.fake import FakeListLLM -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import AgentAction _PAGE_CONTENT = """This is a page about LangChain. diff --git a/libs/langchain/tests/unit_tests/agents/test_structured_chat.py b/libs/langchain/tests/unit_tests/agents/test_structured_chat.py index 8e77f6be204..a739a9463cf 100644 --- a/libs/langchain/tests/unit_tests/agents/test_structured_chat.py +++ b/libs/langchain/tests/unit_tests/agents/test_structured_chat.py @@ -2,15 +2,16 @@ from textwrap import dedent from typing import Any, Tuple -from langchain.agents.structured_chat.base import StructuredChatAgent -from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser -from langchain.prompts.chat import ( +from langchain_core.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain.schema import AgentAction, AgentFinish -from langchain.tools import Tool +from langchain_core.schema import AgentAction, AgentFinish +from langchain_core.tool import Tool + +from langchain.agents.structured_chat.base import StructuredChatAgent +from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser output_parser = StructuredChatOutputParser() diff --git a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py index 91f0c09bd18..7d046f310a4 100644 --- a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py @@ -3,9 +3,10 @@ from itertools import chain from typing import Any, Dict, List, Optional, Union from uuid import UUID +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.schema.messages import BaseMessage + from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler -from langchain.pydantic_v1 import BaseModel -from langchain.schema.messages import BaseMessage class BaseFakeCallbackHandler(BaseModel): diff --git a/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py b/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py index 32670f3c72d..5d8ce135b1f 100644 --- a/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py +++ b/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py @@ -3,6 +3,7 @@ from typing import List, Tuple from unittest.mock import patch import pytest +from langchain_core.schema import AgentAction, AgentFinish, LLMResult from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.manager import ( @@ -15,7 +16,6 @@ from langchain.callbacks.manager import ( from langchain.callbacks.stdout import StdOutCallbackHandler from 
langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers from langchain.llms.openai import BaseOpenAI -from langchain.schema import AgentAction, AgentFinish, LLMResult from tests.unit_tests.callbacks.fake_callback_handler import ( BaseFakeCallbackHandler, FakeAsyncCallbackHandler, diff --git a/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py b/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py index 1d3d7faa8b2..a10f5cc7827 100644 --- a/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py +++ b/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock from uuid import uuid4 import pytest +from langchain_core.schema import LLMResult from langchain.callbacks import OpenAICallbackHandler from langchain.llms.openai import BaseOpenAI -from langchain.schema import LLMResult @pytest.fixture diff --git a/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py b/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py index c4b03b45bdb..7c48fef6a2a 100644 --- a/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py +++ b/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py @@ -7,12 +7,12 @@ from uuid import uuid4 import pytest from freezegun import freeze_time +from langchain_core.schema import LLMResult +from langchain_core.schema.messages import HumanMessage from langchain.callbacks.manager import CallbackManager from langchain.callbacks.tracers.base import BaseTracer, TracerException from langchain.callbacks.tracers.schemas import Run -from langchain.schema import LLMResult -from langchain.schema.messages import HumanMessage SERIALIZED = {"id": ["llm"]} SERIALIZED_CHAT = {"id": ["chat_model"]} diff --git a/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain.py b/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain.py index e83ddba2037..61060d656aa 100644 --- a/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain.py +++ b/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain.py @@ -6,11 +6,11 @@ from typing import Any, Dict from uuid import UUID import pytest +from langchain_core.schema.output import LLMResult from langsmith import Client from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.callbacks.tracers.schemas import Run -from langchain.schema.output import LLMResult def test_example_id_assignment_threadsafe() -> None: diff --git a/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py b/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py index a737b98dcff..b4cc6794d37 100644 --- a/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py +++ b/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py @@ -7,19 +7,19 @@ from uuid import uuid4 import pytest from freezegun import freeze_time - -from langchain.callbacks.manager import CallbackManager -from langchain.callbacks.tracers.base import BaseTracer, TracerException -from langchain.callbacks.tracers.schemas import Run, TracerSessionV1Base -from langchain.schema import LLMResult -from langchain.schema.callbacks.tracers.langchain_v1 import ( +from langchain_core.callbacks.tracers.langchain_v1 import ( ChainRun, LangChainTracerV1, LLMRun, ToolRun, TracerSessionV1, ) -from langchain.schema.messages import HumanMessage +from langchain_core.schema import LLMResult +from langchain_core.schema.messages import HumanMessage + +from 
langchain.callbacks.manager import CallbackManager +from langchain.callbacks.tracers.base import BaseTracer, TracerException +from langchain.callbacks.tracers.schemas import Run, TracerSessionV1Base TEST_SESSION_ID = 2023 diff --git a/libs/langchain/tests/unit_tests/chains/test_base.py b/libs/langchain/tests/unit_tests/chains/test_base.py index d60e06a8deb..a66e2aae1bd 100644 --- a/libs/langchain/tests/unit_tests/chains/test_base.py +++ b/libs/langchain/tests/unit_tests/chains/test_base.py @@ -2,10 +2,10 @@ from typing import Any, Dict, List, Optional import pytest +from langchain_core.schema import RUN_KEY, BaseMemory from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain -from langchain.schema import RUN_KEY, BaseMemory from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py index 9bd9baf5211..62fff6269e3 100644 --- a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py +++ b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py @@ -3,6 +3,8 @@ from typing import Any, List import pytest +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.schema import format_document from langchain.chains.combine_documents.reduce import ( collapse_docs, @@ -10,8 +12,6 @@ from langchain.chains.combine_documents.reduce import ( ) from langchain.chains.qa_with_sources import load_qa_with_sources_chain from langchain.docstore.document import Document -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import format_document from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation.py b/libs/langchain/tests/unit_tests/chains/test_conversation.py index 42ebcd28d19..2eb88b5ab30 100644 --- a/libs/langchain/tests/unit_tests/chains/test_conversation.py +++ b/libs/langchain/tests/unit_tests/chains/test_conversation.py @@ -1,12 +1,12 @@ """Test conversation chain and memory.""" import pytest +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.schema import BaseMemory from langchain.chains.conversation.base import ConversationChain from langchain.memory.buffer import ConversationBufferMemory from langchain.memory.buffer_window import ConversationBufferWindowMemory from langchain.memory.summary import ConversationSummaryMemory -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import BaseMemory from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py index 038b476cf4b..cd322f1dd66 100644 --- a/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py +++ b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py @@ -1,8 +1,9 @@ """Test conversation chain and memory.""" +from langchain_core.schema import Document + from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain from langchain.llms.fake import FakeListLLM from langchain.memory.buffer import ConversationBufferMemory -from langchain.schema import Document from tests.unit_tests.retrievers.sequential_retriever import SequentialRetriever diff --git a/libs/langchain/tests/unit_tests/chains/test_graph_qa.py 
b/libs/langchain/tests/unit_tests/chains/test_graph_qa.py index 51c968328b2..8c034ffed0f 100644 --- a/libs/langchain/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/langchain/tests/unit_tests/chains/test_graph_qa.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List import pandas as pd +from langchain_core.prompts import PromptTemplate from langchain.chains.graph_qa.cypher import ( GraphCypherQAChain, @@ -12,7 +13,6 @@ from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_Q from langchain.graphs.graph_document import GraphDocument from langchain.graphs.graph_store import GraphStore from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory -from langchain.prompts import PromptTemplate from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/unit_tests/chains/test_hyde.py b/libs/langchain/tests/unit_tests/chains/test_hyde.py index c4becfde95d..f617f9c8c71 100644 --- a/libs/langchain/tests/unit_tests/chains/test_hyde.py +++ b/libs/langchain/tests/unit_tests/chains/test_hyde.py @@ -2,6 +2,8 @@ from typing import Any, List, Optional import numpy as np +from langchain_core.schema import Generation, LLMResult +from langchain_core.schema.embeddings import Embeddings from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -10,8 +12,6 @@ from langchain.callbacks.manager import ( from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.llms.base import BaseLLM -from langchain.schema import Generation, LLMResult -from langchain.schema.embeddings import Embeddings class FakeEmbeddings(Embeddings): diff --git a/libs/langchain/tests/unit_tests/chains/test_llm.py b/libs/langchain/tests/unit_tests/chains/test_llm.py index 54b1e58bcd1..58ede4212ea 100644 --- a/libs/langchain/tests/unit_tests/chains/test_llm.py +++ b/libs/langchain/tests/unit_tests/chains/test_llm.py @@ -4,11 +4,11 @@ from typing import Dict, List, Union from unittest.mock import patch import pytest +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.schema import BaseOutputParser from langchain.chains.llm import LLMChain from langchain.chains.loading import load_chain -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import BaseOutputParser from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/unit_tests/chains/test_memory.py b/libs/langchain/tests/unit_tests/chains/test_memory.py index d727358ca91..4a57d5a1b1e 100644 --- a/libs/langchain/tests/unit_tests/chains/test_memory.py +++ b/libs/langchain/tests/unit_tests/chains/test_memory.py @@ -1,4 +1,5 @@ import pytest +from langchain_core.schema import BaseMemory from langchain.chains.conversation.memory import ( ConversationBufferMemory, @@ -6,7 +7,6 @@ from langchain.chains.conversation.memory import ( ConversationSummaryMemory, ) from langchain.memory import ReadOnlySharedMemory, SimpleMemory -from langchain.schema import BaseMemory from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/unit_tests/chat_loaders/test_telegram.py b/libs/langchain/tests/unit_tests/chat_loaders/test_telegram.py index c35dfbaa2b5..56df813f41d 100644 --- a/libs/langchain/tests/unit_tests/chat_loaders/test_telegram.py +++ b/libs/langchain/tests/unit_tests/chat_loaders/test_telegram.py @@ -5,9 +5,9 @@ import zipfile from typing import Sequence import pytest +from langchain_core.schema import AIMessage, BaseMessage, HumanMessage from 
langchain.chat_loaders import telegram, utils -from langchain.schema import AIMessage, BaseMessage, HumanMessage def _assert_messages_are_equal( diff --git a/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py b/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py index d49a3f225d4..dd8e908459d 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py @@ -3,10 +3,10 @@ import os from typing import List import pytest +from langchain_core.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain.chat_models import ChatAnthropic from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic -from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage os.environ["ANTHROPIC_API_KEY"] = "foo" diff --git a/libs/langchain/tests/unit_tests/chat_models/test_azureml_endpoint.py b/libs/langchain/tests/unit_tests/chat_models/test_azureml_endpoint.py index de1055fc3d5..1324e31669e 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_azureml_endpoint.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_azureml_endpoint.py @@ -3,10 +3,10 @@ import os import pytest +from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, FixtureRequest from langchain.chat_models.azureml_endpoint import AzureMLChatOnlineEndpoint -from langchain.pydantic_v1 import SecretStr @pytest.fixture(scope="class") diff --git a/libs/langchain/tests/unit_tests/chat_models/test_baichuan.py b/libs/langchain/tests/unit_tests/chat_models/test_baichuan.py index 000771b7dc8..fe696a9df39 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_baichuan.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_baichuan.py @@ -1,13 +1,6 @@ import pytest - -from langchain.chat_models.baichuan import ( - _convert_delta_to_message_chunk, - _convert_dict_to_message, - _convert_message_to_dict, - _signature, -) -from langchain.pydantic_v1 import SecretStr -from langchain.schema.messages import ( +from langchain_core.pydantic_v1 import SecretStr +from langchain_core.schema.messages import ( AIMessage, AIMessageChunk, ChatMessage, @@ -17,6 +10,13 @@ from langchain.schema.messages import ( SystemMessage, ) +from langchain.chat_models.baichuan import ( + _convert_delta_to_message_chunk, + _convert_dict_to_message, + _convert_message_to_dict, + _signature, +) + def test__convert_message_to_dict_human() -> None: message = HumanMessage(content="foo") diff --git a/libs/langchain/tests/unit_tests/chat_models/test_ernie.py b/libs/langchain/tests/unit_tests/chat_models/test_ernie.py index a8417017f5a..472157246fc 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_ernie.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_ernie.py @@ -1,13 +1,13 @@ import pytest - -from langchain.chat_models.ernie import _convert_message_to_dict -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( AIMessage, FunctionMessage, HumanMessage, SystemMessage, ) +from langchain.chat_models.ernie import _convert_message_to_dict + def test__convert_dict_to_message_human() -> None: message = HumanMessage(content="foo") diff --git a/libs/langchain/tests/unit_tests/chat_models/test_fireworks.py b/libs/langchain/tests/unit_tests/chat_models/test_fireworks.py index 8db4f9653ac..5f8c130b035 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_fireworks.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_fireworks.py @@ 
-2,10 +2,10 @@ import sys import pytest +from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture from langchain.chat_models import ChatFireworks -from langchain.pydantic_v1 import SecretStr if sys.version_info < (3, 9): pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True) diff --git a/libs/langchain/tests/unit_tests/chat_models/test_google_palm.py b/libs/langchain/tests/unit_tests/chat_models/test_google_palm.py index 8bcb9ee78fb..d3e31d2b33d 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_google_palm.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_google_palm.py @@ -1,13 +1,13 @@ """Test Google PaLM Chat API wrapper.""" import pytest +from langchain_core.schema.messages import AIMessage, HumanMessage, SystemMessage from langchain.chat_models.google_palm import ( ChatGooglePalm, ChatGooglePalmError, _messages_to_prompt_dict, ) -from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage def test_messages_to_prompt_dict_with_valid_messages() -> None: diff --git a/libs/langchain/tests/unit_tests/chat_models/test_hunyuan.py b/libs/langchain/tests/unit_tests/chat_models/test_hunyuan.py index 5bec5eb95d6..e5b1ba6b0af 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_hunyuan.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_hunyuan.py @@ -1,13 +1,6 @@ import pytest - -from langchain.chat_models.hunyuan import ( - _convert_delta_to_message_chunk, - _convert_dict_to_message, - _convert_message_to_dict, - _signature, -) -from langchain.pydantic_v1 import SecretStr -from langchain.schema.messages import ( +from langchain_core.pydantic_v1 import SecretStr +from langchain_core.schema.messages import ( AIMessage, AIMessageChunk, ChatMessage, @@ -17,6 +10,13 @@ from langchain.schema.messages import ( SystemMessage, ) +from langchain.chat_models.hunyuan import ( + _convert_delta_to_message_chunk, + _convert_dict_to_message, + _convert_message_to_dict, + _signature, +) + def test__convert_message_to_dict_human() -> None: message = HumanMessage(content="foo") diff --git a/libs/langchain/tests/unit_tests/chat_models/test_javelin_ai_gateway.py b/libs/langchain/tests/unit_tests/chat_models/test_javelin_ai_gateway.py index 27392fa7f04..bfba529af71 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_javelin_ai_gateway.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_javelin_ai_gateway.py @@ -1,9 +1,9 @@ """Test `Javelin AI Gateway` chat models""" import pytest +from langchain_core.pydantic_v1 import SecretStr from langchain.chat_models import ChatJavelinAIGateway -from langchain.pydantic_v1 import SecretStr @pytest.mark.requires("javelin_sdk") diff --git a/libs/langchain/tests/unit_tests/chat_models/test_openai.py b/libs/langchain/tests/unit_tests/chat_models/test_openai.py index c2337247240..2fe5cc94bd4 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_openai.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_openai.py @@ -4,16 +4,16 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest - -from langchain.adapters.openai import convert_dict_to_message -from langchain.chat_models.openai import ChatOpenAI -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( AIMessage, FunctionMessage, HumanMessage, SystemMessage, ) +from langchain.adapters.openai import convert_dict_to_message +from langchain.chat_models.openai import ChatOpenAI + @pytest.mark.requires("openai") def test_openai_model_param() -> None: diff 
--git a/libs/langchain/tests/unit_tests/docstore/test_arbitrary_fn.py b/libs/langchain/tests/unit_tests/docstore/test_arbitrary_fn.py index 728bfded740..2de54d6839f 100644 --- a/libs/langchain/tests/unit_tests/docstore/test_arbitrary_fn.py +++ b/libs/langchain/tests/unit_tests/docstore/test_arbitrary_fn.py @@ -1,5 +1,6 @@ +from langchain_core.schema import Document + from langchain.docstore.arbitrary_fn import DocstoreFn -from langchain.schema import Document def test_document_found() -> None: diff --git a/libs/langchain/tests/unit_tests/document_loaders/parsers/test_generic.py b/libs/langchain/tests/unit_tests/document_loaders/parsers/test_generic.py index d06b4da47df..121339e5b1d 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/parsers/test_generic.py +++ b/libs/langchain/tests/unit_tests/document_loaders/parsers/test_generic.py @@ -3,11 +3,11 @@ from typing import Iterator import pytest +from langchain_core.schema import Document from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders import Blob from langchain.document_loaders.parsers.generic import MimeTypeBasedParser -from langchain.schema import Document class TestMimeBasedParser: diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_base.py b/libs/langchain/tests/unit_tests/document_loaders/test_base.py index 544113993c2..77df9a031eb 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/test_base.py +++ b/libs/langchain/tests/unit_tests/document_loaders/test_base.py @@ -1,9 +1,10 @@ """Test Base Schema of documents.""" from typing import Iterator +from langchain_core.schema import Document + from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders import Blob -from langchain.schema import Document def test_base_blob_parser() -> None: diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_generic_loader.py b/libs/langchain/tests/unit_tests/document_loaders/test_generic_loader.py index 9d6a2166247..72ba1c6edf6 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/test_generic_loader.py +++ b/libs/langchain/tests/unit_tests/document_loaders/test_generic_loader.py @@ -5,11 +5,11 @@ from pathlib import Path from typing import Generator, Iterator import pytest +from langchain_core.schema import Document from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders import Blob, FileSystemBlobLoader from langchain.document_loaders.generic import GenericLoader -from langchain.schema import Document @pytest.fixture diff --git a/libs/langchain/tests/unit_tests/document_transformers/test_beautiful_soup_transformer.py b/libs/langchain/tests/unit_tests/document_transformers/test_beautiful_soup_transformer.py index 8644ed2e42b..8996dbb341a 100644 --- a/libs/langchain/tests/unit_tests/document_transformers/test_beautiful_soup_transformer.py +++ b/libs/langchain/tests/unit_tests/document_transformers/test_beautiful_soup_transformer.py @@ -1,8 +1,8 @@ """Unit tests for beautiful soup document transformer.""" import pytest +from langchain_core.schema.document import Document from langchain.document_transformers import BeautifulSoupTransformer -from langchain.schema.document import Document @pytest.mark.requires("bs4") diff --git a/libs/langchain/tests/unit_tests/embeddings/test_caching.py b/libs/langchain/tests/unit_tests/embeddings/test_caching.py index bcdded666d3..7da4258ca0a 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_caching.py +++ 
b/libs/langchain/tests/unit_tests/embeddings/test_caching.py @@ -2,9 +2,9 @@ from typing import List import pytest +from langchain_core.schema.embeddings import Embeddings from langchain.embeddings import CacheBackedEmbeddings -from langchain.schema.embeddings import Embeddings from langchain.storage.in_memory import InMemoryStore diff --git a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py index 6a625821eb2..dff9f4b278e 100644 --- a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py +++ b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py @@ -3,6 +3,8 @@ from typing import Any, Dict, List, Optional, Tuple import pytest +from langchain_core.pydantic_v1 import Field +from langchain_core.schema import AgentAction, BaseMessage, OutputParserException from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.evaluation.agents.trajectory_eval_chain import ( @@ -10,8 +12,6 @@ from langchain.evaluation.agents.trajectory_eval_chain import ( TrajectoryEvalChain, TrajectoryOutputParser, ) -from langchain.pydantic_v1 import Field -from langchain.schema import AgentAction, BaseMessage, OutputParserException from langchain.tools.base import tool from tests.unit_tests.llms.fake_chat_model import FakeChatModel diff --git a/libs/langchain/tests/unit_tests/indexes/test_hashed_document.py b/libs/langchain/tests/unit_tests/indexes/test_hashed_document.py index 24bbd115235..1fe61e33936 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_hashed_document.py +++ b/libs/langchain/tests/unit_tests/indexes/test_hashed_document.py @@ -1,7 +1,7 @@ import pytest +from langchain_core.schema import Document from langchain.indexes._api import _HashedDocument -from langchain.schema import Document def test_hashed_document_hashing() -> None: diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py index c15c64ff8bd..ac36b8a4f38 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py +++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py @@ -14,6 +14,8 @@ from unittest.mock import patch import pytest import pytest_asyncio +from langchain_core.schema import Document +from langchain_core.schema.vectorstore import VST, VectorStore import langchain.vectorstores from langchain.document_loaders.base import BaseLoader @@ -21,8 +23,6 @@ from langchain.embeddings.base import Embeddings from langchain.indexes import aindex, index from langchain.indexes._api import _abatch from langchain.indexes._sql_record_manager import SQLRecordManager -from langchain.schema import Document -from langchain.schema.vectorstore import VST, VectorStore class ToyLoader(BaseLoader): diff --git a/libs/langchain/tests/unit_tests/llms/fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/fake_chat_model.py index 8c64574d860..dd79b830348 100644 --- a/libs/langchain/tests/unit_tests/llms/fake_chat_model.py +++ b/libs/langchain/tests/unit_tests/llms/fake_chat_model.py @@ -1,13 +1,14 @@ """Fake Chat Model wrapper for testing purposes.""" from typing import Any, Dict, List, Optional +from langchain_core.schema import ChatGeneration, ChatResult +from langchain_core.schema.messages import AIMessage, BaseMessage + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain.chat_models.base import SimpleChatModel -from langchain.schema import ChatGeneration, 
ChatResult -from langchain.schema.messages import AIMessage, BaseMessage class FakeChatModel(SimpleChatModel): diff --git a/libs/langchain/tests/unit_tests/llms/fake_llm.py b/libs/langchain/tests/unit_tests/llms/fake_llm.py index ad91f08c48b..ba513db5707 100644 --- a/libs/langchain/tests/unit_tests/llms/fake_llm.py +++ b/libs/langchain/tests/unit_tests/llms/fake_llm.py @@ -1,9 +1,10 @@ """Fake LLM wrapper for testing purposes.""" from typing import Any, Dict, List, Mapping, Optional, cast +from langchain_core.pydantic_v1 import validator + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import validator class FakeLLM(LLM): diff --git a/libs/langchain/tests/unit_tests/llms/test_ai21.py b/libs/langchain/tests/unit_tests/llms/test_ai21.py index 87df10ea51b..4c6369e5264 100644 --- a/libs/langchain/tests/unit_tests/llms/test_ai21.py +++ b/libs/langchain/tests/unit_tests/llms/test_ai21.py @@ -1,10 +1,10 @@ """Test AI21 llm""" from typing import cast +from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch from langchain.llms.ai21 import AI21 -from langchain.pydantic_v1 import SecretStr def test_api_key_is_secret_string() -> None: diff --git a/libs/langchain/tests/unit_tests/llms/test_aleph_alpha.py b/libs/langchain/tests/unit_tests/llms/test_aleph_alpha.py index 0d5c787442a..fd564bd4497 100644 --- a/libs/langchain/tests/unit_tests/llms/test_aleph_alpha.py +++ b/libs/langchain/tests/unit_tests/llms/test_aleph_alpha.py @@ -1,10 +1,10 @@ """Test Aleph Alpha specific stuff.""" import pytest +from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch from langchain.llms.aleph_alpha import AlephAlpha -from langchain.pydantic_v1 import SecretStr @pytest.mark.requires("aleph_alpha_client") diff --git a/libs/langchain/tests/unit_tests/llms/test_anyscale.py b/libs/langchain/tests/unit_tests/llms/test_anyscale.py index 6805165a656..2aaae85f62c 100644 --- a/libs/langchain/tests/unit_tests/llms/test_anyscale.py +++ b/libs/langchain/tests/unit_tests/llms/test_anyscale.py @@ -1,9 +1,9 @@ """Test Anyscale llm""" import pytest +from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch from langchain.llms.anyscale import Anyscale -from langchain.pydantic_v1 import SecretStr @pytest.mark.requires("openai") diff --git a/libs/langchain/tests/unit_tests/llms/test_base.py b/libs/langchain/tests/unit_tests/llms/test_base.py index 56d21b40c8f..cf8f15d28b1 100644 --- a/libs/langchain/tests/unit_tests/llms/test_base.py +++ b/libs/langchain/tests/unit_tests/llms/test_base.py @@ -7,10 +7,10 @@ except ImportError: from sqlalchemy.ext.declarative import declarative_base import pytest +from langchain_core.schema import Generation, LLMResult from langchain.cache import InMemoryCache, SQLAlchemyCache from langchain.globals import get_llm_cache, set_llm_cache -from langchain.schema import Generation, LLMResult from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/unit_tests/llms/test_callbacks.py b/libs/langchain/tests/unit_tests/llms/test_callbacks.py index 30cc19cd3f8..166c8528186 100644 --- a/libs/langchain/tests/unit_tests/llms/test_callbacks.py +++ b/libs/langchain/tests/unit_tests/llms/test_callbacks.py @@ -1,7 +1,8 @@ """Test LLM callbacks.""" +from langchain_core.schema.messages import HumanMessage + from langchain.chat_models.fake import FakeListChatModel from langchain.llms.fake import FakeListLLM 
-from langchain.schema.messages import HumanMessage from tests.unit_tests.callbacks.fake_callback_handler import ( FakeCallbackHandler, FakeCallbackHandlerWithChatStart, diff --git a/libs/langchain/tests/unit_tests/llms/test_fireworks.py b/libs/langchain/tests/unit_tests/llms/test_fireworks.py index cdfe6cfe645..26ccc854611 100644 --- a/libs/langchain/tests/unit_tests/llms/test_fireworks.py +++ b/libs/langchain/tests/unit_tests/llms/test_fireworks.py @@ -2,10 +2,10 @@ import sys import pytest +from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture from langchain.llms import Fireworks -from langchain.pydantic_v1 import SecretStr if sys.version_info < (3, 9): pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True) diff --git a/libs/langchain/tests/unit_tests/llms/test_gooseai.py b/libs/langchain/tests/unit_tests/llms/test_gooseai.py index 626f39a2fa1..7abb2734022 100644 --- a/libs/langchain/tests/unit_tests/llms/test_gooseai.py +++ b/libs/langchain/tests/unit_tests/llms/test_gooseai.py @@ -1,10 +1,10 @@ """Test GooseAI""" import pytest +from langchain_core.pydantic_v1 import SecretStr from pytest import MonkeyPatch from langchain.llms.gooseai import GooseAI -from langchain.pydantic_v1 import SecretStr from langchain.utils.openai import is_openai_v1 diff --git a/libs/langchain/tests/unit_tests/llms/test_symblai_nebula.py b/libs/langchain/tests/unit_tests/llms/test_symblai_nebula.py index ab53377a772..01562090c6a 100644 --- a/libs/langchain/tests/unit_tests/llms/test_symblai_nebula.py +++ b/libs/langchain/tests/unit_tests/llms/test_symblai_nebula.py @@ -1,9 +1,9 @@ """Test the Nebula model by Symbl.ai""" +from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch from langchain.llms.symblai_nebula import Nebula -from langchain.pydantic_v1 import SecretStr def test_api_key_is_secret_string() -> None: diff --git a/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr b/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr index b935858c8c1..5b6b0e0778f 100644 --- a/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr +++ b/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr @@ -94,7 +94,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -149,7 +149,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "ChatPromptTemplate" @@ -163,7 +163,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "chat", "HumanMessagePromptTemplate" @@ -173,7 +173,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" @@ -233,7 +233,7 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", + "langchain_core", "prompts", "prompt", "PromptTemplate" diff --git a/libs/langchain/tests/unit_tests/load/test_dump.py b/libs/langchain/tests/unit_tests/load/test_dump.py index 9404e177d74..72f1135823d 100644 --- a/libs/langchain/tests/unit_tests/load/test_dump.py +++ b/libs/langchain/tests/unit_tests/load/test_dump.py @@ -3,15 +3,15 @@ from typing import Any, Dict import pytest +from langchain_core.load.dump import dumps +from langchain_core.load.serializable import Serializable +from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain_core.prompts.prompt import PromptTemplate from langchain.callbacks.tracers.langchain import 
LangChainTracer from langchain.chains.llm import LLMChain from langchain.chat_models.openai import ChatOpenAI from langchain.llms.openai import OpenAI -from langchain.load.dump import dumps -from langchain.load.serializable import Serializable -from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain.prompts.prompt import PromptTemplate class Person(Serializable): diff --git a/libs/langchain/tests/unit_tests/load/test_load.py b/libs/langchain/tests/unit_tests/load/test_load.py index 69df8af280e..34fcbae7443 100644 --- a/libs/langchain/tests/unit_tests/load/test_load.py +++ b/libs/langchain/tests/unit_tests/load/test_load.py @@ -1,12 +1,12 @@ """Test for Serializable base class""" import pytest +from langchain_core.load.dump import dumpd, dumps +from langchain_core.load.load import load, loads +from langchain_core.prompts.prompt import PromptTemplate from langchain.chains.llm import LLMChain from langchain.llms.openai import OpenAI -from langchain.load.dump import dumpd, dumps -from langchain.load.load import load, loads -from langchain.prompts.prompt import PromptTemplate class NotSerializable: diff --git a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_file.py b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_file.py index a2351671c4f..19b17a799d1 100644 --- a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_file.py +++ b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_file.py @@ -3,9 +3,9 @@ from pathlib import Path from typing import Generator import pytest +from langchain_core.schema.messages import AIMessage, HumanMessage from langchain.memory.chat_message_histories import FileChatMessageHistory -from langchain.schema.messages import AIMessage, HumanMessage @pytest.fixture diff --git a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_sql.py b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_sql.py index a01e1b77dbe..5e451988a7f 100644 --- a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_sql.py +++ b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_sql.py @@ -2,12 +2,12 @@ from pathlib import Path from typing import Any, Generator, Tuple import pytest +from langchain_core.schema.messages import AIMessage, HumanMessage from sqlalchemy import Column, Integer, Text from sqlalchemy.orm import DeclarativeBase from langchain.memory.chat_message_histories import SQLChatMessageHistory from langchain.memory.chat_message_histories.sql import DefaultMessageConverter -from langchain.schema.messages import AIMessage, HumanMessage @pytest.fixture() diff --git a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_streamlit.py b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_streamlit.py index 5ed50191c1f..c2c88b9c728 100644 --- a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_streamlit.py +++ b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_streamlit.py @@ -6,7 +6,7 @@ test_script = """ import streamlit as st from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import StreamlitChatMessageHistory - from langchain.schema.messages import _message_to_dict + from langchain_core.schema.messages import _message_to_dict message_history = StreamlitChatMessageHistory() memory = ConversationBufferMemory(chat_memory=message_history, return_messages=True) diff --git 
a/libs/langchain/tests/unit_tests/memory/test_combined_memory.py b/libs/langchain/tests/unit_tests/memory/test_combined_memory.py index b056e53c4ec..fcd2131060b 100644 --- a/libs/langchain/tests/unit_tests/memory/test_combined_memory.py +++ b/libs/langchain/tests/unit_tests/memory/test_combined_memory.py @@ -1,5 +1,5 @@ """Test for CombinedMemory class""" -# from langchain.prompts import PromptTemplate +# from langchain_core.prompts import PromptTemplate from typing import List import pytest diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_enum_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_enum_parser.py index f1992b40892..ee35e85bbbd 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_enum_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_enum_parser.py @@ -1,7 +1,8 @@ from enum import Enum +from langchain_core.schema import OutputParserException + from langchain.output_parsers.enum import EnumOutputParser -from langchain.schema import OutputParserException class Colors(Enum): diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_json.py b/libs/langchain/tests/unit_tests/output_parsers/test_json.py index 6fe7cb27dd1..c7b4792e06c 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_json.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_json.py @@ -2,6 +2,7 @@ import json from typing import Any, AsyncIterator, Iterator, Tuple import pytest +from langchain_core.schema.messages import AIMessageChunk from langchain.output_parsers.json import ( SimpleJsonOutputParser, @@ -9,7 +10,6 @@ from langchain.output_parsers.json import ( parse_partial_json, ) from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser -from langchain.schema.messages import AIMessageChunk GOOD_JSON = """```json { diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py b/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py index 364f5a5a22e..a86696603dc 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py @@ -1,12 +1,12 @@ from typing import Any, Dict import pytest +from langchain_core.schema import BaseMessage, ChatGeneration, OutputParserException +from langchain_core.schema.messages import AIMessage, HumanMessage from langchain.output_parsers.openai_functions import ( JsonOutputFunctionsParser, ) -from langchain.schema import BaseMessage, ChatGeneration, OutputParserException -from langchain.schema.messages import AIMessage, HumanMessage def test_json_output_function_parser() -> None: diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py index 2e75860f19b..b7b693256cd 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -2,9 +2,10 @@ from enum import Enum from typing import Optional +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.schema import OutputParserException + from langchain.output_parsers.pydantic import PydanticOutputParser -from langchain.pydantic_v1 import BaseModel, Field -from langchain.schema import OutputParserException class Actions(Enum): diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_structured_parser.py 
b/libs/langchain/tests/unit_tests/output_parsers/test_structured_parser.py index abf19a7c097..9a59a6ea9e1 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_structured_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_structured_parser.py @@ -1,5 +1,6 @@ +from langchain_core.schema import OutputParserException + from langchain.output_parsers import ResponseSchema, StructuredOutputParser -from langchain.schema import OutputParserException def test_parse() -> None: diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py index cb9909c9015..b0eea6dfdf7 100644 --- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py +++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, Tuple, Union import pytest +from langchain_core.schema import Document from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, @@ -16,7 +17,6 @@ from langchain.chains.query_constructor.ir import ( ) from langchain.chains.query_constructor.schema import AttributeInfo from langchain.retrievers import SelfQueryRetriever -from langchain.schema import Document from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py b/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py index b75913a9617..41830958b08 100644 --- a/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py +++ b/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py @@ -1,6 +1,6 @@ from typing import List -from langchain.schema import BaseRetriever, Document +from langchain_core.schema import BaseRetriever, Document class SequentialRetriever(BaseRetriever): diff --git a/libs/langchain/tests/unit_tests/retrievers/test_base.py b/libs/langchain/tests/unit_tests/retrievers/test_base.py index 45237f50d0c..90b83c12d52 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_base.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_base.py @@ -5,12 +5,12 @@ from __future__ import annotations from typing import Dict, List, Optional import pytest +from langchain_core.schema import BaseRetriever, Document from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain.schema import BaseRetriever, Document from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/langchain/tests/unit_tests/retrievers/test_bm25.py b/libs/langchain/tests/unit_tests/retrievers/test_bm25.py index f021d708e7e..1fd1512e8f5 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_bm25.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_bm25.py @@ -1,7 +1,7 @@ import pytest +from langchain_core.schema import Document from langchain.retrievers.bm25 import BM25Retriever -from langchain.schema import Document @pytest.mark.requires("rank_bm25") diff --git a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py index 2488ff0643b..231dbf035b7 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py @@ -1,8 +1,8 @@ import pytest +from langchain_core.schema import Document from langchain.retrievers.bm25 import BM25Retriever from 
langchain.retrievers.ensemble import EnsembleRetriever -from langchain.schema import Document @pytest.mark.requires("rank_bm25") diff --git a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py index 978950ec58a..59ba4463f5f 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py @@ -1,9 +1,9 @@ from typing import List import pytest as pytest +from langchain_core.schema import Document from langchain.retrievers.multi_query import _unique_documents -from langchain.schema import Document @pytest.mark.parametrize( diff --git a/libs/langchain/tests/unit_tests/retrievers/test_remote_retriever.py b/libs/langchain/tests/unit_tests/retrievers/test_remote_retriever.py index acae7bdff9f..a77abf8fc17 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_remote_retriever.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_remote_retriever.py @@ -1,9 +1,9 @@ from typing import Any, Dict +from langchain_core.schema import Document from pytest_mock import MockerFixture from langchain.retrievers import RemoteLangChainRetriever -from langchain.schema import Document class MockResponse: diff --git a/libs/langchain/tests/unit_tests/retrievers/test_svm.py b/libs/langchain/tests/unit_tests/retrievers/test_svm.py index 491be75a6cf..9379648b254 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_svm.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_svm.py @@ -1,8 +1,8 @@ import pytest +from langchain_core.schema import Document from langchain.embeddings import FakeEmbeddings from langchain.retrievers.svm import SVMRetriever -from langchain.schema import Document class TestSVMRetriever: diff --git a/libs/langchain/tests/unit_tests/retrievers/test_tfidf.py b/libs/langchain/tests/unit_tests/retrievers/test_tfidf.py index 484fa6f0c69..34a78caad46 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_tfidf.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_tfidf.py @@ -3,9 +3,9 @@ from datetime import datetime from tempfile import TemporaryDirectory import pytest +from langchain_core.schema import Document from langchain.retrievers.tfidf import TFIDFRetriever -from langchain.schema import Document @pytest.mark.requires("sklearn") diff --git a/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py b/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py index abe220441aa..4fc2d42a132 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py @@ -4,14 +4,14 @@ from datetime import datetime, timedelta from typing import Any, Iterable, List, Optional, Tuple, Type import pytest +from langchain_core.schema import Document +from langchain_core.schema.embeddings import Embeddings +from langchain_core.schema.vectorstore import VectorStore from langchain.retrievers.time_weighted_retriever import ( TimeWeightedVectorStoreRetriever, _get_hours_passed, ) -from langchain.schema import Document -from langchain.schema.embeddings import Embeddings -from langchain.schema.vectorstore import VectorStore def _get_example_memories(k: int = 4) -> List[Document]: diff --git a/libs/langchain/tests/unit_tests/retrievers/test_you.py b/libs/langchain/tests/unit_tests/retrievers/test_you.py index 54ab1cb8e2b..703b5ddf9cd 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_you.py +++ 
b/libs/langchain/tests/unit_tests/retrievers/test_you.py @@ -2,10 +2,10 @@ import json import os from unittest import mock +from langchain_core.schema import Document from requests import Response from langchain.retrievers.you import YouRetriever -from langchain.schema import Document class TestYouRetriever: diff --git a/libs/langchain/tests/unit_tests/runnables/test_hub.py b/libs/langchain/tests/unit_tests/runnables/test_hub.py index 7294006fa38..f17700ef2c6 100644 --- a/libs/langchain/tests/unit_tests/runnables/test_hub.py +++ b/libs/langchain/tests/unit_tests/runnables/test_hub.py @@ -1,9 +1,10 @@ from typing import Any from unittest.mock import Mock, patch -from langchain.prompts import ChatPromptTemplate +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables.base import ConfigurableField + from langchain.runnables.hub import HubRunnable -from langchain.schema.runnable.base import ConfigurableField @patch("langchain.hub.pull") @@ -36,7 +37,7 @@ def repo_lookup(owner_repo_commit: str, **kwargs: Any) -> ChatPromptTemplate: def test_hub_runnable_configurable_alternative(mock_pull: Mock) -> None: mock_pull.side_effect = repo_lookup - original: HubRunnable[Any, Any] = HubRunnable("efriis/my-prompt-1") + original: HubRunnable = HubRunnable("efriis/my-prompt-1") obj_a1 = original.configurable_alternatives( ConfigurableField(id="owner_repo_commit", name="Hub ID"), default_key="a1", @@ -58,7 +59,7 @@ def test_hub_runnable_configurable_alternative(mock_pull: Mock) -> None: def test_hub_runnable_configurable_fields(mock_pull: Mock) -> None: mock_pull.side_effect = repo_lookup - original: HubRunnable[Any, Any] = HubRunnable("efriis/my-prompt-1") + original: HubRunnable = HubRunnable("efriis/my-prompt-1") obj_configurable = original.configurable_fields( owner_repo_commit=ConfigurableField(id="owner_repo_commit", name="Hub ID"), ) diff --git a/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py b/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py index e4cec167d88..7c78fd36024 100644 --- a/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py @@ -1,14 +1,14 @@ from typing import Any, List, Optional +from langchain_core.schema import ChatResult +from langchain_core.schema.messages import AIMessage, BaseMessage +from langchain_core.schema.output import ChatGeneration from pytest_mock import MockerFixture from syrupy import SnapshotAssertion from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import BaseChatModel from langchain.runnables.openai_functions import OpenAIFunctionsRouter -from langchain.schema import ChatResult -from langchain.schema.messages import AIMessage, BaseMessage -from langchain.schema.output import ChatGeneration class FakeChatOpenAI(BaseChatModel): diff --git a/libs/langchain/tests/unit_tests/schema/test_imports.py b/libs/langchain/tests/unit_tests/schema/test_imports.py index 93c7dff3c51..5bc2f228798 100644 --- a/libs/langchain/tests/unit_tests/schema/test_imports.py +++ b/libs/langchain/tests/unit_tests/schema/test_imports.py @@ -1,4 +1,4 @@ -from langchain.schema import __all__ +from langchain_core.schema import __all__ EXPECTED_ALL = [ "BaseCache", diff --git a/libs/langchain/tests/unit_tests/schema/test_messages.py b/libs/langchain/tests/unit_tests/schema/test_messages.py index b72c10010df..6cfe6d2649b 100644 --- a/libs/langchain/tests/unit_tests/schema/test_messages.py +++ 
b/libs/langchain/tests/unit_tests/schema/test_messages.py @@ -1,6 +1,5 @@ import pytest - -from langchain.schema.messages import ( +from langchain_core.schema.messages import ( AIMessageChunk, ChatMessageChunk, FunctionMessageChunk, diff --git a/libs/langchain/tests/unit_tests/schema/test_output.py b/libs/langchain/tests/unit_tests/schema/test_output.py index 84ec53cdddd..5e086c5e5a3 100644 --- a/libs/langchain/tests/unit_tests/schema/test_output.py +++ b/libs/langchain/tests/unit_tests/schema/test_output.py @@ -1,5 +1,5 @@ -from langchain.schema.messages import HumanMessageChunk -from langchain.schema.output import ChatGenerationChunk, GenerationChunk +from langchain_core.schema.messages import HumanMessageChunk +from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk def test_generation_chunk() -> None: diff --git a/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py b/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py index 3d3a6b1a401..061cf0e740c 100644 --- a/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py +++ b/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py @@ -6,12 +6,12 @@ from unittest import mock import pytest from freezegun import freeze_time +from langchain_core.schema.language_model import BaseLanguageModel from langsmith.client import Client from langsmith.schemas import Dataset, Example from langchain.chains.base import Chain from langchain.chains.transform import TransformChain -from langchain.schema.language_model import BaseLanguageModel from langchain.smith.evaluation.runner_utils import ( InputFormatError, _get_messages, diff --git a/libs/langchain/tests/unit_tests/storage/test_lc_store.py b/libs/langchain/tests/unit_tests/storage/test_lc_store.py index 5d15683ac31..5b1eff98c82 100644 --- a/libs/langchain/tests/unit_tests/storage/test_lc_store.py +++ b/libs/langchain/tests/unit_tests/storage/test_lc_store.py @@ -2,8 +2,8 @@ import tempfile from typing import Generator, cast import pytest +from langchain_core.schema import Document -from langchain.schema import Document from langchain.storage._lc_store import create_kv_docstore, create_lc_store from langchain.storage.file_system import LocalFileStore diff --git a/libs/langchain/tests/unit_tests/test_cache.py b/libs/langchain/tests/unit_tests/test_cache.py index b413fa41c16..e9751747658 100644 --- a/libs/langchain/tests/unit_tests/test_cache.py +++ b/libs/langchain/tests/unit_tests/test_cache.py @@ -3,6 +3,12 @@ from typing import Dict, Generator, List, Union import pytest from _pytest.fixtures import FixtureRequest +from langchain_core.load import dumps +from langchain_core.schema import ( + ChatGeneration, + Generation, +) +from langchain_core.schema.messages import AIMessage, BaseMessage, HumanMessage from sqlalchemy import create_engine from sqlalchemy.orm import Session @@ -11,15 +17,10 @@ from langchain.cache import ( SQLAlchemyCache, ) from langchain.chat_models import FakeListChatModel -from langchain.chat_models.base import BaseChatModel, dumps +from langchain.chat_models.base import BaseChatModel from langchain.globals import get_llm_cache, set_llm_cache from langchain.llms import FakeListLLM from langchain.llms.base import BaseLLM -from langchain.schema import ( - ChatGeneration, - Generation, -) -from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage def get_sqlite_cache() -> SQLAlchemyCache: diff --git a/libs/langchain/tests/unit_tests/test_dependencies.py 
b/libs/langchain/tests/unit_tests/test_dependencies.py index 8db357cf7bf..e23e3fb3c70 100644 --- a/libs/langchain/tests/unit_tests/test_dependencies.py +++ b/libs/langchain/tests/unit_tests/test_dependencies.py @@ -91,6 +91,8 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None: def test_imports() -> None: """Test that you can import all top level things okay.""" + from langchain_core.schema import BasePromptTemplate # noqa: F401 + from langchain.agents import OpenAIFunctionsAgent # noqa: F401 from langchain.callbacks import OpenAICallbackHandler # noqa: F401 from langchain.chains import LLMChain # noqa: F401 @@ -99,7 +101,6 @@ def test_imports() -> None: from langchain.embeddings import OpenAIEmbeddings # noqa: F401 from langchain.llms import OpenAI # noqa: F401 from langchain.retrievers import VespaRetriever # noqa: F401 - from langchain.schema import BasePromptTemplate # noqa: F401 from langchain.tools import DuckDuckGoSearchResults # noqa: F401 from langchain.utilities import SerpAPIWrapper # noqa: F401 from langchain.vectorstores import FAISS # noqa: F401 diff --git a/libs/langchain/tests/unit_tests/test_formatting.py b/libs/langchain/tests/unit_tests/test_formatting.py index ef941932da2..096fd13d306 100644 --- a/libs/langchain/tests/unit_tests/test_formatting.py +++ b/libs/langchain/tests/unit_tests/test_formatting.py @@ -1,7 +1,6 @@ """Test formatting functionality.""" import pytest - -from langchain.utils import formatter +from langchain_core.utils import formatter def test_valid_formatting() -> None: diff --git a/libs/langchain/tests/unit_tests/test_globals.py b/libs/langchain/tests/unit_tests/test_globals.py index 8249d39d65d..5df209c165c 100644 --- a/libs/langchain/tests/unit_tests/test_globals.py +++ b/libs/langchain/tests/unit_tests/test_globals.py @@ -2,8 +2,9 @@ from langchain.globals import get_debug, get_verbose, set_debug, set_verbose def test_debug_is_settable_directly() -> None: + from langchain_core.callbacks.manager import _get_debug + import langchain - from langchain.schema.callbacks.manager import _get_debug previous_value = langchain.debug previous_fn_reading = _get_debug() @@ -32,8 +33,9 @@ def test_debug_is_settable_directly() -> None: def test_debug_is_settable_via_setter() -> None: + from langchain_core.callbacks.manager import _get_debug + from langchain import globals - from langchain.schema.callbacks.manager import _get_debug previous_value = globals._debug previous_fn_reading = _get_debug() diff --git a/libs/langchain/tests/unit_tests/test_schema.py b/libs/langchain/tests/unit_tests/test_schema.py index 4ee90e03baa..7f833236339 100644 --- a/libs/langchain/tests/unit_tests/test_schema.py +++ b/libs/langchain/tests/unit_tests/test_schema.py @@ -3,19 +3,18 @@ import unittest from typing import Union import pytest - -from langchain.prompts.base import StringPromptValue -from langchain.prompts.chat import ChatPromptValueConcrete -from langchain.pydantic_v1 import BaseModel, ValidationError -from langchain.schema import ( +from langchain_core.prompts.base import StringPromptValue +from langchain_core.prompts.chat import ChatPromptValueConcrete +from langchain_core.pydantic_v1 import BaseModel, ValidationError +from langchain_core.schema import ( AgentAction, AgentFinish, ChatGeneration, Document, Generation, ) -from langchain.schema.agent import AgentActionMessageLog -from langchain.schema.messages import ( +from langchain_core.schema.agent import AgentActionMessageLog +from langchain_core.schema.messages import ( AIMessage, 
AIMessageChunk, ChatMessage, @@ -30,7 +29,7 @@ from langchain.schema.messages import ( messages_from_dict, messages_to_dict, ) -from langchain.schema.output import ChatGenerationChunk +from langchain_core.schema.output import ChatGenerationChunk class TestGetBufferString(unittest.TestCase): diff --git a/libs/langchain/tests/unit_tests/test_utils.py b/libs/langchain/tests/unit_tests/test_utils.py index 525cebef4da..d94f2f276e7 100644 --- a/libs/langchain/tests/unit_tests/test_utils.py +++ b/libs/langchain/tests/unit_tests/test_utils.py @@ -1,6 +1,5 @@ import pytest - -from langchain.utils import check_package_version +from langchain_core.utils import check_package_version def test_check_package_version_pass() -> None: diff --git a/libs/langchain/tests/unit_tests/tools/openapi/test_api_models.py b/libs/langchain/tests/unit_tests/tools/openapi/test_api_models.py index 9aa415625ff..70614822174 100644 --- a/libs/langchain/tests/unit_tests/tools/openapi/test_api_models.py +++ b/libs/langchain/tests/unit_tests/tools/openapi/test_api_models.py @@ -8,7 +8,7 @@ import pytest # Keep at top of file to ensure that pydantic test can be skipped before # pydantic v1 related imports are attempted by openapi_pydantic. -from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION +from langchain_core.pydantic_v1 import _PYDANTIC_MAJOR_VERSION if _PYDANTIC_MAJOR_VERSION != 1: pytest.skip( diff --git a/libs/langchain/tests/unit_tests/tools/test_exported.py b/libs/langchain/tests/unit_tests/tools/test_exported.py index 1e6c0f19537..78eb7e50461 100644 --- a/libs/langchain/tests/unit_tests/tools/test_exported.py +++ b/libs/langchain/tests/unit_tests/tools/test_exported.py @@ -21,9 +21,11 @@ def _get_tool_classes(skip_tools_without_default_names: bool) -> List[Type[BaseT if isinstance(tool_class, type) and issubclass(tool_class, BaseTool): if tool_class in _EXCLUDE: continue - if skip_tools_without_default_names and tool_class.__fields__[ - "name" - ].default in [None, ""]: + if ( + skip_tools_without_default_names + and tool_class.__fields__["name"].default # type: ignore + in [None, ""] + ): continue results.append(tool_class) return results diff --git a/libs/langchain/tests/unit_tests/utilities/test_loading.py b/libs/langchain/tests/unit_tests/utilities/test_loading.py index c74df087d14..961bcbd8733 100644 --- a/libs/langchain/tests/unit_tests/utilities/test_loading.py +++ b/libs/langchain/tests/unit_tests/utilities/test_loading.py @@ -9,8 +9,7 @@ from urllib.parse import urljoin import pytest import responses - -from langchain.utils.loading import DEFAULT_REF, URL_BASE, try_load_from_hub +from langchain_core.utils.loading import DEFAULT_REF, URL_BASE, try_load_from_hub @pytest.fixture(autouse=True) diff --git a/libs/langchain/tests/unit_tests/utils/test_iter.py b/libs/langchain/tests/unit_tests/utils/test_iter.py index f0fd8bf4ce5..01a400f9d37 100644 --- a/libs/langchain/tests/unit_tests/utils/test_iter.py +++ b/libs/langchain/tests/unit_tests/utils/test_iter.py @@ -1,8 +1,7 @@ from typing import List import pytest - -from langchain.utils.iter import batch_iterate +from langchain_core.utils.iter import batch_iterate @pytest.mark.parametrize( diff --git a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py index b5a22d837b9..804fafa7d06 100644 --- a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py @@ -1,4 +1,5 @@ -from langchain.pydantic_v1 import BaseModel, 
Field +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain.utils.openai_functions import convert_pydantic_to_openai_function diff --git a/libs/langchain/tests/unit_tests/vectorstores/redis/test_redis_schema.py b/libs/langchain/tests/unit_tests/vectorstores/redis/test_redis_schema.py index f6ec2e86d2a..c0f19629b3e 100644 --- a/libs/langchain/tests/unit_tests/vectorstores/redis/test_redis_schema.py +++ b/libs/langchain/tests/unit_tests/vectorstores/redis/test_redis_schema.py @@ -36,7 +36,7 @@ def test_numeric_field_schema_creation() -> None: def test_redis_vector_field_validation() -> None: """Test validation for RedisVectorField's datatype.""" - from langchain.pydantic_v1 import ValidationError + from langchain_core.pydantic_v1 import ValidationError with pytest.raises(ValidationError): RedisVectorField( diff --git a/libs/langchain/tests/unit_tests/vectorstores/test_imports.py b/libs/langchain/tests/unit_tests/vectorstores/test_imports.py index 16a5dc0b87a..633dbad9957 100644 --- a/libs/langchain/tests/unit_tests/vectorstores/test_imports.py +++ b/libs/langchain/tests/unit_tests/vectorstores/test_imports.py @@ -1,5 +1,6 @@ +from langchain_core.schema.vectorstore import VectorStore + from langchain import vectorstores -from langchain.schema.vectorstore import VectorStore def test_all_imports() -> None: